{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Postsynaptic trace computation\n", "==============================\n", "\n", "Pre- and postsynaptic traces are used to calculate STDP weight updates, but are computed differently: postsynaptic traces are stored and maintained in the NEST C++ `ArchivingNode` class. Following [nest-simulator#1034](https://github.com/nest/nest-simulator/issues/1034), this notebook (and corresponding unit test in `test_regression_issue-1034.py`) was created to specifically test the postsynaptic trace value, by comparing the NEST-obtained samples to a Python-generated reference timeseries.\n", "\n", "Construct a network of the form:\n", "- pre_spike_gen connects via static_synapse to pre_parrot\n", "- pre_parrot connects via stdp_synapse to post_parrot\n", "- post_spike_gen connects via static_synapse to post_parrot\n", "\n", "The spike times of the spike generators are defined in\n", "`pre_spike_times` and `post_spike_times`. From the perspective of the\n", "STDP synapse, spikes arrive with the following delays (with respect to\n", "the values in these lists):\n", "\n", "- for the presynaptic neuron: one synaptic delay in the static synapse\n", "- for the postsynaptic neuron: one synaptic delay in the static synapse\n", "- for the synapse itself: one dendritic delay between the post_parrot\n", " node and the synapse itself (see the C++ variable `dendritic_delay`)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as plticker\n", "import nest\n", "import numpy as np\n", "import os\n", "import scipy as sp\n", "import scipy.stats\n", "import unittest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NEST simulation\n", "---------------\n", "\n", "Construct and run the NEST network." 
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_post_trace_test_nest_(pre_spike_times, post_spike_times,\n",
    "                              resolution, delay, sim_time, tau_minus,\n",
    "                              show_all_nest_trace_samples=False,\n",
    "                              debug=False):\n",
    "\n",
    "    if debug:\n",
    "        print(\"Pre spike times: [\"\n",
    "              + \", \".join([str(t) for t in pre_spike_times]) + \"]\")\n",
    "        print(\"Post spike times: [\"\n",
    "              + \", \".join([str(t) for t in post_spike_times]) + \"]\")\n",
    "\n",
    "    nest.set_verbosity(\"M_WARNING\")\n",
    "\n",
    "    nest.ResetKernel()\n",
    "    nest.resolution = resolution\n",
    "\n",
    "    wr = nest.Create(\"weight_recorder\")\n",
    "    nest.CopyModel(\"stdp_synapse\", \"stdp_synapse_rec\",\n",
    "                   {\"weight_recorder\": wr, \"weight\": 1.})\n",
    "\n",
    "    # create spike_generators with these times\n",
    "    pre_sg_ps = nest.Create(\"spike_generator\",\n",
    "                            params={\"spike_times\": pre_spike_times,\n",
    "                                    \"precise_times\": True})\n",
    "    post_sg_ps = nest.Create(\"spike_generator\",\n",
    "                             params={\"spike_times\": post_spike_times,\n",
    "                                     \"precise_times\": True})\n",
    "\n",
    "    # create parrot neurons and connect spike_generators\n",
    "    pre_parrot_ps = nest.Create(\"parrot_neuron_ps\")\n",
    "    post_parrot_ps = nest.Create(\"parrot_neuron_ps\",\n",
    "                                 params={\"tau_minus\": tau_minus})\n",
    "\n",
    "    nest.Connect(pre_sg_ps, pre_parrot_ps, syn_spec={\"delay\": delay})\n",
    "    nest.Connect(post_sg_ps, post_parrot_ps, syn_spec={\"delay\": delay})\n",
    "\n",
    "    # create spike recorder --- debugging only\n",
    "    spikes = nest.Create(\"spike_recorder\")\n",
    "    nest.Connect(pre_parrot_ps + post_parrot_ps, spikes)\n",
    "\n",
    "    # connect both parrot neurons with a stdp synapse onto port 1\n",
    "    # so spikes transmitted through the stdp connection are\n",
    "    # not repeated postsynaptically.\n",
    "    nest.Connect(\n",
    "        pre_parrot_ps, post_parrot_ps,\n",
    "        syn_spec={'synapse_model': 'stdp_synapse_rec',\n",
    "                  'receptor_type': 1,\n",
    "                  'delay': delay})\n",
    "\n",
    "    if debug:\n",
    "        print(\"[py] Total simulation time: \" + str(sim_time) + \" ms\")\n",
    "\n",
    "    n_steps = int(np.ceil(sim_time / delay))\n",
    "    trace_nest = []\n",
    "    trace_nest_t = []\n",
    "\n",
    "    t = nest.biological_time\n",
    "    trace_nest_t.append(t)\n",
    "\n",
    "    post_tr = post_parrot_ps.post_trace\n",
    "    trace_nest.append(post_tr)\n",
    "\n",
    "    for step in range(n_steps):\n",
    "        if debug:\n",
    "            print(\"\\n[py] simulating for \" + str(delay) + \" ms\")\n",
    "        nest.Simulate(delay)\n",
    "        t = nest.biological_time\n",
    "        nearby_pre_spike = np.any(\n",
    "            np.abs(t - np.array(pre_spike_times) - delay) < resolution/2.)\n",
    "        if show_all_nest_trace_samples or nearby_pre_spike:\n",
    "            trace_nest_t.append(t)\n",
    "            post_tr = post_parrot_ps.post_trace\n",
    "            trace_nest.append(post_tr)\n",
    "            if debug:\n",
    "                print(\"[py] Received NEST trace: \" +\n",
    "                      str(post_tr) + \" at time t = \" + str(t))\n",
    "\n",
    "    return trace_nest_t, trace_nest"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Python simulation\n",
    "-----------------\n",
    "\n",
    "Generate the Python reference timeseries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_post_trace_test_python_reference_(pre_spike_times,\n",
    "                                          post_spike_times, resolution,\n",
    "                                          dendritic_delay, sim_time,\n",
    "                                          tau_minus, debug=False):\n",
    "    \"\"\"\n",
    "    compute Python known-good reference of postsynaptic trace\n",
    "    \"\"\"\n",
    "\n",
    "    n_timepoints = 1000 * int(np.ceil(sim_time))\n",
    "    trace_python_ref = np.zeros(n_timepoints)\n",
    "\n",
    "    n_spikes = len(post_spike_times)\n",
    "    for sp_idx in range(n_spikes):\n",
    "        t_sp = post_spike_times[sp_idx] + 2 * dendritic_delay\n",
    "        for i in range(n_timepoints):\n",
    "            t = (i / float(n_timepoints - 1)) * sim_time\n",
    "            if t > t_sp:\n",
    "                trace_python_ref[i] += np.exp(-(t - t_sp) / tau_minus)\n",
    "\n",
    "    n_spikes = len(pre_spike_times)\n",
    "    for sp_idx in range(n_spikes):\n",
    "        t_sp = pre_spike_times[sp_idx] + dendritic_delay\n",
    "        i = int(np.round(t_sp / sim_time\n",
    "                         * float(len(trace_python_ref) - 1)))\n",
    "        if debug:\n",
    "            print(\"* At t_sp = \" + str(t_sp)\n",
    "                  + \", post_trace should be \" + str(trace_python_ref[i]))\n",
    "\n",
    "    return trace_python_ref"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the test\n",
    "------------\n",
    "\n",
    "First, define some pre/post spike patterns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# spike test pattern: generate some random integer spike times\n",
    "t_sp_min = 1.\n",
    "t_sp_max = 50\n",
    "\n",
    "# pre spikes\n",
    "n_spikes = 10\n",
    "pre_spike_times = np.sort(\n",
    "    np.unique(\n",
    "        np.ceil(\n",
    "            sp.stats.uniform.rvs(\n",
    "                t_sp_min, t_sp_max - t_sp_min, n_spikes))))\n",
    "\n",
    "# post spikes\n",
    "n_spikes = 50\n",
    "post_spike_times = np.sort(\n",
    "    np.unique(\n",
    "        np.ceil(\n",
    "            sp.stats.uniform.rvs(\n",
    "                t_sp_min, t_sp_max - t_sp_min, n_spikes))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a function that will validate equality between the Python-generated and the NEST-generated timeseries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "trace_match_atol = 1E-3\n",
    "trace_match_rtol = 1E-3\n",
    "\n",
    "def nest_trace_matches_ref_trace(trace_nest_t, trace_nest,\n",
    "                                 trace_python_ref, pre_spike_times,\n",
    "                                 post_spike_times, resolution,\n",
    "                                 dendritic_delay, trace_match_atol,\n",
    "                                 trace_match_rtol, sim_time,\n",
    "                                 debug=False):\n",
    "    \"\"\"\n",
    "    Trace values are returned from NEST at regular intervals, but only\n",
    "    updated at presynaptic spike times.\n",
    "\n",
    "    To match the NEST samples with the continuous reference trace, step\n",
    "    backwards in time from the sampled value, to find the last time at\n",
    "    which the trace value was updated, namely the time of occurrence of\n",
    "    the last presynaptic spike.\n",
    "    \"\"\"\n",
    "\n",
    "    n_timepoints = len(trace_nest_t)\n",
    "    for i in range(n_timepoints)[1:]:\n",
    "        t = trace_nest_t[i]\n",
    "        if debug:\n",
    "            print(\"* Finding ref for NEST timepoint t = \" + str(t)\n",
    "                  + \", NEST trace = \" + str(trace_nest[i]))\n",
    "\n",
    "        traces_match = False\n",
    "        for i_search, t_search in enumerate(\n",
    "                reversed(np.array(pre_spike_times) + dendritic_delay)):\n",
    "            if t_search <= t:\n",
    "                _trace_at_t_search = trace_python_ref[int(np.round(\n",
    "                    t_search / sim_time\n",
    "                    * float(len(trace_python_ref) - 1)))]\n",
    "                traces_match = np.allclose(\n",
    "                    _trace_at_t_search,\n",
    "                    trace_nest[i],\n",
    "                    atol=trace_match_atol,\n",
    "                    rtol=trace_match_rtol)\n",
    "                post_spike_occurred_at_t_search = np.any(\n",
    "                    (t_search - (np.array(post_spike_times)\n",
    "                                 + 2 * dendritic_delay))**2\n",
    "                    < resolution/2.)\n",
    "\n",
    "                if debug:\n",
    "                    print(\"\\t* Testing \" + str(t_search) + \"...\")\n",
    "                    print(\"\\t traces_match = \" + str(traces_match))\n",
    "                    print(\"\\t post_spike_occurred_at_t_search = \"\n",
    "                          + str(post_spike_occurred_at_t_search))\n",
    "\n",
    "                if (not traces_match) and post_spike_occurred_at_t_search:\n",
    "                    traces_match = np.allclose(\n",
    "                        _trace_at_t_search + 1,\n",
    "                        trace_nest[i],\n",
    "                        atol=trace_match_atol,\n",
    "                        rtol=trace_match_rtol)\n",
    "                    if debug:\n",
    "                        print(\"\\t traces_match = \" + str(traces_match)\n",
    "                              + \" (nest trace = \" + str(trace_nest[i])\n",
    "                              + \", ref trace = \"\n",
    "                              + str(_trace_at_t_search + 1)\n",
    "                              + \")\")\n",
    "                    if traces_match:\n",
    "                        _trace_at_t_search += 1.\n",
    "\n",
    "                if (not traces_match) and post_spike_occurred_at_t_search:\n",
    "                    traces_match = np.allclose(\n",
    "                        _trace_at_t_search - 1,\n",
    "                        trace_nest[i],\n",
    "                        atol=trace_match_atol,\n",
    "                        rtol=trace_match_rtol)\n",
    "                    if debug:\n",
    "                        print(\"\\t traces_match = \" + str(traces_match)\n",
    "                              + \" (nest trace = \" + str(trace_nest[i])\n",
    "                              + \", ref trace = \"\n",
    "                              + str(_trace_at_t_search - 1)\n",
    "                              + \")\")\n",
    "                    if traces_match:\n",
    "                        _trace_at_t_search -= 1.\n",
    "\n",
    "                break\n",
    "\n",
    "        if (not traces_match) and i_search == len(pre_spike_times) - 1:\n",
    "            if debug:\n",
    "                print(\"\\tthe time before the first pre spike\")\n",
    "            # the time before the first pre spike\n",
    "            traces_match = trace_nest[i] == 0.\n",
    "\n",
    "        if not traces_match:\n",
    "            return False\n",
    "\n",
    "    return True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_run(trace_nest_t, trace_nest, trace_python_ref,\n",
    "             pre_spike_times, post_spike_times, resolution,\n",
    "             dendritic_delay, trace_match_atol, trace_match_rtol,\n",
    "             sim_time, title_snip=\"\", debug=False):\n",
    "\n",
    "    fig, ax = plt.subplots(nrows=3, dpi=120)\n",
    "    ax1, ax2, ax3 = ax\n",
    "\n",
    "    #\n",
    "    # pre spikes\n",
    "    #\n",
    "\n",
    "    ax1.set_ylim([0., 1.])\n",
    "    ax1.set_ylabel(\"Pre spikes\")\n",
    "    n_spikes = len(pre_spike_times)\n",
    "    for i in range(n_spikes):\n",
    "        ax1.plot(2 * [pre_spike_times[i] + dendritic_delay],\n",
    "                 ax1.get_ylim(),\n",
    "                 linewidth=2, color=\"blue\", alpha=.4)\n",
    "\n",
    "    #\n",
    "    # post spikes\n",
    "    #\n",
    "\n",
    "    ax2.set_ylim([0., 1.])\n",
    "    ax2.set_ylabel(\"Post spikes\")\n",
    "    n_spikes = len(post_spike_times)\n",
    "    for i in range(n_spikes):\n",
    "        ax2.plot(2 * [post_spike_times[i] + 2 * dendritic_delay],\n",
    "                 [0, 1],\n",
    "                 linewidth=2, color=\"red\", alpha=.4)\n",
    "\n",
    "    #\n",
    "    # traces\n",
    "    #\n",
    "\n",
    "    ax3.set_ylabel(\"Synaptic trace\")\n",
    "    ax3.set_ylim([0., np.amax(trace_python_ref)])\n",
    "    ax3.plot(np.linspace(0., sim_time, len(trace_python_ref)),\n",
    "             trace_python_ref,\n",
    "             label=\"Expected\", color=\"cyan\", alpha=.6)\n",
    "    ax3.scatter(trace_nest_t, trace_nest,\n",
    "                marker=\".\", alpha=.5, color=\"orange\", label=\"NEST\")\n",
    "    ax3.legend()\n",
    "\n",
    "    #\n",
    "    # Trace values are returned from NEST at regular intervals, but only\n",
    "    # updated at presynaptic spike times.\n",
    "    #\n",
    "    # Step backwards in time from the sampled value, to find the last\n",
    "    # time at which the trace value was updated, namely the time of\n",
    "    # occurrence of the last presynaptic spike.\n",
    "    #\n",
    "\n",
    "    pre_spike_times = np.array(pre_spike_times)\n",
    "    n_timepoints = len(trace_nest_t)\n",
    "    for i in range(n_timepoints):\n",
    "        t = trace_nest_t[i]\n",
    "        if debug:\n",
    "            print(\"* Finding ref for NEST timepoint t = \"\n",
    "                  + str(t) + \", trace = \" + str(trace_nest[i]))\n",
    "        for t_search in reversed(pre_spike_times + dendritic_delay):\n",
    "            if t_search <= t:\n",
    "                if debug:\n",
    "                    print(\"\\t* Testing \" + str(t_search) + \"...\")\n",
    "                _idx = int(np.round(t_search / sim_time\n",
    "                                    * float(len(trace_python_ref) - 1)))\n",
    "                _trace_at_t_search = trace_python_ref[_idx]\n",
    "                traces_match = np.allclose(_trace_at_t_search,\n",
    "                                           trace_nest[i],\n",
    "                                           atol=trace_match_atol,\n",
    "                                           rtol=trace_match_rtol)\n",
    "                if debug:\n",
    "                    print(\"\\t traces_match = \" + str(traces_match))\n",
    "                if not traces_match:\n",
    "                    post_spike_occurred_at_t_search = np.any(\n",
    "                        (t_search - (np.array(post_spike_times)\n",
    "                                     + 2 * dendritic_delay))**2 < resolution/2.)\n",
    "                    if debug:\n",
    "                        print(\"\\t post_spike_occurred_at_t_search = \"\n",
    "                              + str(post_spike_occurred_at_t_search))\n",
    "                    if post_spike_occurred_at_t_search:\n",
    "                        traces_match = np.allclose(\n",
    "                            _trace_at_t_search + 1,\n",
    "                            trace_nest[i],\n",
    "                            atol=trace_match_atol,\n",
    "                            rtol=trace_match_rtol)\n",
    "                        if debug:\n",
    "                            print(\"\\t traces_match = \" + str(traces_match)\n",
    "                                  + \" (nest trace = \" + str(trace_nest[i])\n",
    "                                  + \", ref trace = \"\n",
    "                                  + str(_trace_at_t_search+1) + \")\")\n",
    "\n",
    "                        if traces_match:\n",
    "                            _trace_at_t_search += 1.\n",
    "\n",
    "                        if not traces_match:\n",
    "                            traces_match = np.allclose(\n",
    "                                _trace_at_t_search - 1,\n",
    "                                trace_nest[i],\n",
    "                                atol=trace_match_atol,\n",
    "                                rtol=trace_match_rtol)\n",
    "\n",
    "                            if debug:\n",
    "                                print(\"\\t traces_match = \"\n",
    "                                      + str(traces_match)\n",
    "                                      + \" (nest trace = \"\n",
    "                                      + str(trace_nest[i])\n",
    "                                      + \", ref trace = \"\n",
    "                                      + str(_trace_at_t_search-1) + \")\")\n",
    "\n",
    "                            if traces_match:\n",
    "                                _trace_at_t_search -= 1.\n",
    "\n",
    "                ax3.scatter(t_search, _trace_at_t_search, 100, marker=\".\",\n",
    "                            color=\"#A7FF00FF\", facecolor=\"none\")\n",
    "                ax3.plot([trace_nest_t[i], t_search],\n",
    "                         [trace_nest[i], _trace_at_t_search],\n",
    "                         linewidth=.5, color=\"#0000007F\")\n",
    "                break\n",
    "\n",
    "    for _ax in ax:\n",
    "        _ax.xaxis.set_major_locator(\n",
    "            plticker.MultipleLocator(base=10*dendritic_delay))\n",
    "        _ax.xaxis.set_minor_locator(\n",
    "            plticker.MultipleLocator(base=dendritic_delay))\n",
    "        _ax.grid(which=\"major\", axis=\"both\")\n",
    "        _ax.grid(which=\"minor\", axis=\"x\", linestyle=\":\", alpha=.4)\n",
    "        _ax.set_xlim(0., sim_time)\n",
    "\n",
    "    ax3.set_xlabel(\"Time [ms]\")\n",
    "    fig.suptitle(title_snip)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, run the test and make the plots as we go.\n",
    "\n",
    "The plots should be interpreted as follows. Pre- and postsynaptic spikes are shown in the top two subplots, at the time at which they arrive at the synapse (i.e. from the perspective of the synapse, taking dendritic and axonal delays into account).\n",
    "\n",
    "The bottom subplot shows the reference (known-good) timeseries generated numerically in Python (**cyan line**). The values returned from NEST are shown as **orange circles**. They are plotted as points rather than as a continuous line, because the value can only be retrieved at the resolution of the minimum synaptic delay (i.e. fetch trace value; simulate for a timestep `delay`; repeat). Moreover, NEST only updates the postsynaptic trace value while processing a presynaptic spike, so unless a presynaptic spike was processed in the last delay interval, the value remains unchanged. To allow comparison between the Python- and NEST-generated values, we therefore search for the previous time at which NEST would have updated the trace value, which is the time of arrival of the last presynaptic spike. The reference value at that time is marked by an **open green circle**. If all is well, every green circle should overlap an orange circle, and all **black lines** (which connect each NEST sample to the reference value it is compared against) should be perfectly horizontal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "resolution = .1  # [ms]\n",
    "dendritic_delay = 1.  # [ms]\n",
    "tau_minus = 2.  # [ms]\n",
    "\n",
    "# settings for plotting debug information\n",
    "show_all_nest_trace_samples = True\n",
    "\n",
    "max_t_sp = max(np.amax(pre_spike_times),\n",
    "               np.amax(post_spike_times))\n",
    "sim_time = max_t_sp + 5 * dendritic_delay\n",
    "trace_nest_t, trace_nest = run_post_trace_test_nest_(\n",
    "    pre_spike_times,\n",
    "    post_spike_times,\n",
    "    resolution, dendritic_delay, sim_time, tau_minus,\n",
    "    show_all_nest_trace_samples)\n",
    "trace_python_ref = run_post_trace_test_python_reference_(\n",
    "    pre_spike_times,\n",
    "    post_spike_times,\n",
    "    resolution, dendritic_delay, sim_time, tau_minus)\n",
    "\n",
    "title_snip = \"Dendritic delay = \" + str(dendritic_delay)\n",
    "plot_run(\n",
    "    trace_nest_t, trace_nest, trace_python_ref,\n",
    "    pre_spike_times,\n",
    "    post_spike_times, resolution,\n",
    "    dendritic_delay, trace_match_atol, trace_match_rtol,\n",
    "    sim_time, title_snip)\n",
    "assert nest_trace_matches_ref_trace(\n",
    "    trace_nest_t,\n",
    "    trace_nest,\n",
    "    trace_python_ref,\n",
    "    pre_spike_times,\n",
    "    post_spike_times,\n",
    "    resolution, dendritic_delay,\n",
    "    trace_match_atol,\n",
    "    trace_match_rtol,\n",
    "    sim_time,\n",
    "    debug=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "License\n",
    "-------\n",
    "\n",
    "This file is part of NEST. Copyright (C) 2004 The NEST Initiative\n",
    "\n",
    "NEST is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 2 of the License, or (at your option) any later version.\n",
    "\n",
    "NEST is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 4 }