Weight normalization


Suppose that the incoming synaptic weights of a neuron are given as \(\mathbf{w}=(w_1, w_2, \ldots, w_n)\). A plasticity rule might require that the vector norm \(|\mathbf{w}|\) remain constant. For example, the L1-norm \(|\mathbf{w}|_1\) is used in [1], [2]:

\[|\mathbf{w}|_1 = \sum_i |w_i|\]
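As a small worked example with arbitrarily chosen numbers, the absolute values mean that negative (e.g. inhibitory) weights add to the norm just like positive ones:

\[\mathbf{w} = (0.5,\ -1.5,\ 2.0) \quad\Rightarrow\quad |\mathbf{w}|_1 = |0.5| + |-1.5| + |2.0| = 4.0\]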

Keeping this norm constant at a desired target value, say \(w_{target}\), is typically done as an extra step after the main weight plasticity step (for example, after an STDP weight update): first the norm is computed, and then all weights \(w_1, \ldots, w_n\) are rescaled according to:

\[w_i \leftarrow w_{target} \frac{w_i}{|\mathbf{w}|_1}\]
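The rescaling step itself can be sketched independently of NEST with NumPy. This is a minimal sketch: the function name `normalize_l1` is hypothetical, and leaving an all-zero weight vector unchanged is an assumption of this sketch (the formula above is undefined when the norm is zero).

```python
import numpy as np


def normalize_l1(w, w_target=1.0):
    """Rescale the weight vector w so that its L1-norm equals w_target."""
    w = np.asarray(w, dtype=float)
    norm = np.sum(np.abs(w))  # |w|_1 = sum_i |w_i|
    if norm == 0.0:
        # All-zero weights cannot be rescaled to a nonzero norm; return a copy unchanged.
        return w.copy()
    return w_target * w / norm


w_new = normalize_l1([0.5, -1.5, 2.0], w_target=2.0)
print(w_new.tolist())  # [0.25, -0.75, 1.0]
```

Because every weight is divided by the same positive number, the sign and the relative size of each weight are preserved; only the overall scale changes.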

Implementation in NEST

Because of the way the data structures are arranged in NEST, normalizing the weights is a costly operation in terms of run time. One has to iterate over all neurons, fetch all incoming connections of each neuron, compute the vector norm, perform the actual normalization, and finally write the new weights back.

This would look something like:

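import nest
import numpy as np

def normalize_weights(neurons_to_be_normalized, w_target=1.):
    for neuron in neurons_to_be_normalized:
        conn = nest.GetConnections(target=neuron)
        # conn.weight is a single value if there is only one connection
        w = np.atleast_1d(np.array(conn.weight, dtype=float))
        norm = np.sum(np.abs(w))  # L1-norm
        if norm > 0:  # skip neurons without nonzero incoming weights
            conn.weight = (w_target * w / norm).tolist()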

To apply normalization only to a certain synapse type, GetConnections() can be restricted to return only synapses of that type by specifying the model name, for example GetConnections(..., synapse_model="stdp_synapse").

To be formally correct, weight normalization would have to be performed after every simulation time step. However, weights typically evolve on a much longer timescale than the simulation resolution, so this would be very inefficient. Depending on how fast your weights change, you may want to normalize them, say, every 100 ms of simulated time, every 1 s, or even less frequently. The interval can be chosen based on how far the norm is allowed to drift from \(w_{target}\): longer intervals allow more drift. The magnitude of the drift can be measured at the end of each interval by subtracting the current norm from its target, before the normalized weights are written back to the NEST connection objects.
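The drift measurement can be sketched as follows, assuming the incoming weights of one neuron have already been read into an array (for example from `conn.weight` as in `normalize_weights()`). The helper names `l1_drift` and `relative_drift` are hypothetical.

```python
import numpy as np


def l1_drift(w, w_target=1.0):
    """Return w_target - |w|_1 for one neuron's incoming weights.

    Positive: the norm has fallen below the target.
    Negative: the norm has grown above the target.
    """
    w = np.atleast_1d(np.asarray(w, dtype=float))
    return w_target - np.sum(np.abs(w))


def relative_drift(w, w_target=1.0):
    """Magnitude of the drift as a fraction of the target norm."""
    return abs(l1_drift(w, w_target)) / w_target


# Weights after one interval, with target norm 1.0: |w|_1 = 1.15
print(l1_drift([0.3, 0.4, 0.45], w_target=1.0))
```

One possible heuristic (an assumption, not a NEST feature): if the relative drift regularly exceeds a tolerance you care about, shorten the normalization interval; if it stays negligible, lengthen it.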

To summarize, the basic strategy is to divide the total simulation time into intervals of, say, 100 ms: simulate for 100 ms, pause the simulation and normalize the weights (using the code above), then continue with the next interval.
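This interval scheme can be sketched as a small driver loop. The function name `simulate_with_normalization` is hypothetical; `simulate` and `normalize` are passed in as callables so that the loop itself does not depend on NEST. With NEST, `simulate` would be `nest.Simulate` and `normalize` a call to `normalize_weights()`. Running a shorter final chunk when the total time is not a multiple of the interval is a design choice of this sketch.

```python
def simulate_with_normalization(t_total, interval, simulate, normalize):
    """Simulate t_total ms in chunks of `interval` ms, normalizing after each chunk.

    Both times should be multiples of the simulation resolution.
    Returns the number of chunks that were simulated.
    """
    if interval <= 0:
        raise ValueError("interval must be positive")
    n_full, remainder = divmod(t_total, interval)
    chunks = [interval] * int(n_full)
    if remainder > 0:
        chunks.append(remainder)  # shorter final chunk
    for t in chunks:
        simulate(t)   # advance the network by t ms
        normalize()   # rescale the weights before continuing
    return len(chunks)


# With NEST (assumes `neurons` and normalize_weights() are defined as above):
# simulate_with_normalization(10000., 100., nest.Simulate,
#                             lambda: normalize_weights(neurons, w_target=1.))
```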