The Case of Spiking Neural Networks
Traditional artificial neural networks have attained incredible, even at times super-human, performance on specific well-defined tasks like object detection, recognition, and sequential decision-making. But whether the neurons in our human brains work the same way is questionable. It is generally believed that our neural architectures are not feed-forward, but contain many more loops, with many feedback effects on all levels. Additionally, biological neurons have the temporal dimension naturally built into their workings, whereas the perceptron has no notion of time and cannot distinguish between different orderings of the same sequence of inputs. This post explores spiking neural networks, which address the above concerns and seem to more closely mimic the neurons in the real world.
To derive the behaviour of a neuron we can construct an electric circuit from the observed neuron's morphology and solve that circuit. Figure 1 shows a simplified neuron with which we can compare.
The neuron is a cell specialized in the transmission of electric pulses. It has a cell body, called soma, which contains the nucleus where most protein synthesis processes occur. There are also multiple dendrites - extensions of the neuron which bring in signals from the other neurons into the body, and an axon which takes information away from the body. The dendrites are covered with synapses which are the specific contact points where the transfer of signals, i.e. electric pulses, between neurons occurs.
To make an analogy (that may sacrifice correctness in favour of simplicity) with artificial neurons, we can think of it as follows. The dendrites determine how many inputs there are. The synapses represent the actual learnable weights \(w_i\). The sign and magnitude of each weight represent the excitatory or inhibitory effect it has on the output. The role of the soma is to aggregate the inputs \(x_1, ..., x_N\) by summing up all the modulated signals - outputs from other neurons times weights - and the axon, along with the signal \(y\) travelling in it, represents the output. The non-linearity function \(\sigma\) does not fit particularly well in this picture but we can agree its nature is less biologically-inspired and more of a practical necessity. This forms the standard formula for the perceptron
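Written out in the notation above, and assuming no bias term since the text does not mention one, this is the familiar weighted sum passed through the non-linearity:

\[ y = \sigma\left(\sum_{i=1}^{N} w_i x_i\right) \]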
To improve over this model we can take a closer look at another component - the membrane surrounding the soma. It is a bilayer of lipid molecules that separates the solutions within and outside of the neuron body. In particular, the membrane acts as an insulator between the two conductive solutions and can thus be modelled as a capacitor. The resulting voltage is measured between the inside and the outside of the membrane. Moreover, the membrane also has ion channels which are pathways that allow various ions (of sodium, chloride, etc.) to flow through it. These are voltage-gated and can be opened or closed depending on the voltage difference across the membrane. As a result, we can model these ion channels as resistors in a circuit.
With our circuit defined, we can solve it. The input current splits across the capacitor and the resistor \(I(t) = I_R(t) + I_C(t)\). For a capacitor, the total charge \(Q(t)\) is equal to the product of the capacitance \(C\) and the instantaneous voltage \(U(t)\), \(Q(t) = C U(t)\). The total charge changes as \(\frac{d Q(t)}{d t} = C \frac{d U(t)}{dt} = I_C(t)\). For the resistor, Ohm's law states that if \(U(t)\) is the voltage between two points and \(R\) is the resistance, then the current is \(I_R(t) = U(t)/R\). Hence,
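Substituting the resistor and capacitor currents into the current split gives the following, reconstructed from the three relations above:

\[ I(t) = \frac{U(t)}{R} + C \frac{d U(t)}{dt} \quad \Longleftrightarrow \quad RC \frac{d U(t)}{dt} = -U(t) + R I(t) \]

The product \(RC\) acts as the time constant of the membrane: the larger it is, the more slowly \(U(t)\) reacts to the input.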
Solving this yields
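For the special case of a constant input current \(I\) and an initial potential \(U_0\) at \(t = 0\) (both assumptions made here to obtain a closed form), the solution of \(RC \frac{d U(t)}{dt} = -U(t) + R I\) is:

\[ U(t) = I R + \left(U_0 - I R\right) e^{-t/(RC)} \]

Differentiating confirms it: \(RC \frac{d U(t)}{dt} = -\left(U_0 - I R\right) e^{-t/(RC)} = -U(t) + R I\).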
This behaves like a simple tracking controller: the membrane potential \(U(t)\) is continually pulled towards \(I(t) R\). If we supply the neuron membrane with a constant current, then the potential \(U(t)\) will approach the value \(I(t) R\) exponentially and will stabilize there, because \(I(t)\) is constant. If on the other hand, \(I(t)\) is rapidly changing, then \(U(t)\) will also change rapidly, trying to catch up to \(I(t)R\). If \(I(t)\) is zero, \(U(t)\) will tend to 0. If \(I(t)\) is oscillating, \(U(t)\) will also oscillate. And if \(I(t)\) is a pulse input, then \(U(t)\) will spike up initially and then quickly decay towards 0.
If the input current is spiky, being non-zero only occasionally and for very short amounts of time, then the membrane potential also starts to look spiky through time. In an infinitesimally short period \(dt\) during which \(I(t)\) is non-zero, the rise in \(U(t)\) is almost infinitely steep, owing to the rapid change in \(I(t)\), after which it starts to decay gradually.
So far, the neuron has an adaptive membrane potential but no real output. It is common to model the output as a binary spike which is emitted if the membrane potential reaches a predefined threshold. Additionally, it has been observed that neurons do not fire constantly even if their steady-state potential \(I(t) R\) is above the threshold. This requires adding a cooldown period, or a reset mechanism, which decrements the membrane potential before the neuron can fire again. This is called hyperpolarization.
Thus, let the membrane potential threshold be \(\bar{U}\) and the spike variable be \(S(t)\) where
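Following the description above, the spike variable is an indicator of the potential crossing the threshold (whether the comparison is strict is an assumption here):

\[ S(t) = \begin{cases} 1 & \text{if } U(t) > \bar{U} \\ 0 & \text{otherwise} \end{cases} \]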
To illustrate the whole idea we discretize the differential equation above and implement the spikes and the reset mechanism:
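A minimal sketch of such a discretization, using a forward Euler step on \(RC \frac{dU}{dt} = -U(t) + R I(t)\). The function name, the parameter values and the step size are illustrative assumptions, not taken from the original post:

```python
def simulate_lif(current, R=5.0, C=1e-3, dt=1e-4, threshold=1.0, u0=0.0):
    """Simulate one leaky integrate-and-fire neuron with forward Euler.

    `current` holds one input current value per time step.
    Returns (potentials, spikes), two lists with one entry per step.
    All default values are illustrative assumptions.
    """
    tau = R * C  # membrane time constant
    u = u0
    potentials, spikes = [], []
    for i_t in current:
        # Euler step of RC * dU/dt = -U + R * I
        u = u + (dt / tau) * (-u + R * i_t)
        # Binary spike when the potential exceeds the threshold
        s = 1 if u > threshold else 0
        # Reset by subtraction: lower the potential by the threshold
        u = u - s * threshold
        potentials.append(u)
        spikes.append(s)
    return potentials, spikes


# A weak constant current settles at I * R = 0.5, below the threshold,
# so the neuron never fires.
quiet_u, quiet_s = simulate_lif([0.1] * 1000)

# A stronger current has I * R = 1.5, above the threshold, so the neuron
# fires repeatedly and is reset after every spike.
busy_u, busy_s = simulate_lif([0.3] * 1000)
```

The potential is recorded after the reset, so the stored trace never exceeds the threshold.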
The reset mechanism instantly decrements the current potential by the spiking threshold if a spike has occurred. This is one way to model it; it is also possible to set the potential directly to zero. The combination of the simple RC-circuit, the spike threshold and the reset mechanism is called the leaky integrate-and-fire model. It was proposed by Louis Lapicque back in 1907. Since then, there have been other models like the adaptive integrate-and-fire where the frequency of spikes decreases even with constant input current, and the Hodgkin–Huxley model which has additional components.
At this point, we can start using the Lapicque neuron in algorithms. However, it's useful to simplify it a bit. The actual functional form used in practice typically assumes that \(R=1\) and is
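A form consistent with the description that follows, and with the reset-by-subtraction mechanism described earlier, would be (the exact time indices are an assumption):

\[ U(t+1) = \beta U(t) + I(t+1) - S(t) \bar{U} \]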
where \(\beta\) is the decay rate controlling how fast the membrane potential decays. The input current \(I(t)\) is also replaced with a weighted input \(w x(t)\), where \(w\) is a learnable weight and \(x(t)\) is binary, indicating spike or no spike in the data. We can now stack multiple such neurons in layers. Suppose the input \(X\) has shape \((T, B, D)\) with \(T\) time steps, \(B\) samples in the batch, and \(D\) dimensions in each sample at each time step. We initialize all the membranes as \(U^{(0)}\). \(X^{(0)}\) is the data in the first timestep with shape \((B, D)\) which goes through the first layer, outputting \(y^{(1)}\) and \(U^{(1)}\). Both of these go into the second layer, outputting \(y^{(2)}\) and \(U^{(2)}\) and so on. \(X^{(1)}\) is the data from the next timestep which is processed in a similar way.
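The layer-by-layer, step-by-step processing can be sketched in plain Python with nested lists. This is an illustration under stated assumptions: the names `lif_layer_step` and `run_network` are hypothetical, every layer keeps its own membrane state across time steps, and the reset is applied within the same step as the spike for brevity.

```python
def lif_layer_step(x, u, weights, beta=0.9, threshold=1.0):
    """One time step of a fully connected layer of LIF neurons.

    x:       B input vectors of length D_in (binary spikes).
    u:       B membrane vectors of length D_out (state from the last step).
    weights: D_out rows, each with D_in weights.
    Returns (spikes, new_u), both of shape (B, D_out).
    """
    spikes, new_u = [], []
    for x_b, u_b in zip(x, u):
        s_row, u_row = [], []
        for w_j, u_j in zip(weights, u_b):
            current = sum(w * xi for w, xi in zip(w_j, x_b))  # w x(t)
            u_next = beta * u_j + current                     # leaky integration
            s = 1 if u_next > threshold else 0                # spike
            u_row.append(u_next - s * threshold)              # reset by subtraction
            s_row.append(s)
        spikes.append(s_row)
        new_u.append(u_row)
    return spikes, new_u


def run_network(X, layer_weights, beta=0.9, threshold=1.0):
    """Run an input of shape (T, B, D) through a stack of LIF layers.

    Returns the output spikes of the last layer, shape (T, B, D_last).
    """
    batch = len(X[0])
    # One zero-initialized membrane state per layer
    membranes = [[[0.0] * len(w) for _ in range(batch)] for w in layer_weights]
    outputs = []
    for x_t in X:                        # iterate over time steps
        h = x_t
        for k, w in enumerate(layer_weights):
            h, membranes[k] = lif_layer_step(h, membranes[k], w, beta, threshold)
        outputs.append(h)
    return outputs


# T = 3 steps, B = 1 sample, D = 2 inputs, layers with 2 and 1 neurons
X = [[[1, 1]], [[0, 0]], [[1, 0]]]
layers = [[[0.6, 0.6], [0.2, 0.2]], [[1.5, 0.0]]]
out = run_network(X, layers, beta=0.5)
```

Only the spikes are passed on to the next layer within a time step, while the membrane potentials carry information from one time step to the next.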
When it comes to training, this setup is amenable to backpropagation through time - the same variation of backpropagation used to train RNNs. Here, we're interested in computing the gradient of the loss with respect to the weights
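By the chain rule, and written for a single time step for brevity (the full gradient sums such terms over all time steps), this reads:

\[ \frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial S} \, \frac{\partial S}{\partial U} \, \frac{\partial U}{\partial I} \, \frac{\partial I}{\partial W} \]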
where \(S\) is the spike, \(U\) is the membrane potential, \(I\) is the input current, and \(W\) are the weights. The problem, however, is that \(S\) is not differentiable with respect to \(U\) at \(\bar{U}\), and \(\partial S/\partial U\) is zero everywhere else. This is sometimes called the dead neuron problem since no actual updates can take place.
Note that if the only problem were the non-differentiability at a single point, a method specialized for handling it, like the subgradient, could be used. Here, however, the spike is essentially given by the Heaviside step function, shifted to \(\bar{U}\), and the derivative with respect to \(U\) is zero almost everywhere. To fix this problem we can approximate \(\partial S / \partial{U}\) by replacing it with a better-behaved surrogate gradient term.
There are a few good choices here. One is to simply replace \(\partial S(t) / \partial{U}\) with \(S(t)\). As a result, neurons which fire at time \(t\) will have their gradients backpropagated because \(S(t) = 1\). Those which don't fire won't backpropagate. Implementing this in practice requires saving the value of \(S(t)\) during the forward pass, and using it during the backward pass.
Another common approach is to smooth the Heaviside function during backprop, for example using the fast sigmoid:
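One common form, with a slope parameter \(k\) that is introduced here as an assumption, is:

\[ S \approx \frac{U - \bar{U}}{1 + k \lvert U - \bar{U} \rvert}, \qquad \frac{\partial S}{\partial U} \approx \frac{1}{\left(1 + k \lvert U - \bar{U} \rvert\right)^2} \]

The derivative peaks at the threshold and falls off on both sides, so neurons whose potential is close to \(\bar{U}\) receive the largest updates.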
In that case during the forward pass we compute the spike normally using the Heaviside function but save \(U\) so that we're able to compute \(\partial S(t) / \partial{U}\) during backpropagation. Overall, replacing the actual gradient term with something else that is easily computable is strange. But it has become common practice, and the resulting spiking neural networks do seem to be fairly robust to these kinds of changes.
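Both surrogate choices can be sketched as a forward function that saves what the backward pass needs, plus two alternative backward functions. This is a framework-free illustration; the function names and the slope `k` are assumptions, and a real implementation would hook into the autograd mechanism of a deep learning library.

```python
def spike_forward(u, threshold=1.0):
    """Forward pass: shifted Heaviside step.

    Returns the spike and a dict of values saved for the backward pass.
    """
    s = 1.0 if u > threshold else 0.0
    saved = {"u": u, "s": s, "threshold": threshold}
    return s, saved


def spike_backward_spike(grad_s, saved):
    """Surrogate 1: replace dS/dU with the saved spike S itself.

    The incoming gradient passes through only for neurons that fired.
    """
    return grad_s * saved["s"]


def spike_backward_fast_sigmoid(grad_s, saved, k=25.0):
    """Surrogate 2: derivative of the fast sigmoid, 1 / (1 + k|U - thr|)^2.

    The slope k is an assumed hyperparameter.
    """
    x = saved["u"] - saved["threshold"]
    return grad_s / (1.0 + k * abs(x)) ** 2


# A neuron slightly above the threshold fires and receives gradient
# under both surrogates.
s, saved = spike_forward(1.2)
g1 = spike_backward_spike(2.0, saved)
g2 = spike_backward_fast_sigmoid(2.0, saved)

# A silent neuron gets no gradient from surrogate 1 but a small,
# non-zero gradient from surrogate 2.
s0, saved0 = spike_forward(0.5)
h1 = spike_backward_spike(2.0, saved0)
h2 = spike_backward_fast_sigmoid(2.0, saved0)
```

The second surrogate lets sub-threshold neurons keep learning, which is one reason a smoothed derivative may be preferred over reusing the spike.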
To be able to train, we also need to define the loss function and how the output spikes will be interpreted. For a single sample at test time, the usual expected shape is \((T, D)\) and in a classification task with \(C\) classes the output will have shape \((T, C)\). The target is of shape \((T, 1)\). One way to determine the prediction is to count the number of spikes for each class across time and pick the class that fires the most. To make the correct class fire the most, we incentivize its neuron to have a high membrane potential. This can be achieved by softmax-ing the membrane potentials at a given step \(t\) and computing the cross-entropy loss:
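In symbols, with \(U(t)\) denoting the vector of the \(C\) output membrane potentials at step \(t\) and \(\mathrm{CE}\) the cross-entropy (both notational assumptions), the loss can be written as:

\[ \mathcal{L} = \mathbb{E}_{(X, Y) \sim \mathcal{D}} \left[ \sum_{t=1}^{T} \mathrm{CE}\big(\mathrm{softmax}(U(t)), \textbf{y}(t)\big) \right], \qquad \mathrm{CE}(\textbf{p}, \textbf{y}) = -\sum_{c=1}^{C} y_c \log p_c \]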
That is, as in a typical RNN scenario, the final loss is the mean sequence cross-entropy across all training samples \((X, Y)\) from the dataset distribution \(\mathcal{D}\). The loss on each sequence is the sum of the cross-entropy terms across all time steps. The cross-entropy for a single time step is computed from the normalized membrane potentials and the one-hot encoded target vector \(\textbf{y}(t)\).
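For a single sequence, the loss and the spike-count prediction can be sketched as follows. The helper names are hypothetical and targets are given as class indices rather than one-hot vectors, which yields the same cross-entropy value:

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_cross_entropy(potentials, targets):
    """Sum of per-step cross-entropy terms for one sequence.

    potentials: T lists with the C output membrane potentials per step.
    targets:    T class indices (the position of the 1 in the one-hot target).
    """
    loss = 0.0
    for u_t, y_t in zip(potentials, targets):
        p = softmax(u_t)
        loss -= math.log(p[y_t])  # one-hot target keeps a single term
    return loss


def predict_by_spike_count(spikes):
    """Return the class whose output neuron fired most often over time.

    spikes: T lists with C binary spikes each. Ties go to the lowest index.
    """
    num_classes = len(spikes[0])
    counts = [sum(s_t[c] for s_t in spikes) for c in range(num_classes)]
    return counts.index(max(counts))


# Uniform potentials give log(2) per step for two classes.
loss = sequence_cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1])

# Class 1 fires twice, class 0 once, so class 1 is predicted.
pred = predict_by_spike_count([[1, 0], [0, 1], [0, 1]])
```

Averaging `sequence_cross_entropy` over a batch of sequences gives the mean sequence loss described above.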
The simple leaky integrate-and-fire is not the only neuron model that is commonly used in practice. Two variations are worth mentioning here. The first variation is concerned with modelling the synaptic current more accurately. The Lapicque model described above assumes that a spike in the pre-synaptic neuron causes an instantaneous increase in the current passing through the synapse. There are experiments showing that this is not the case and in reality this current increases gradually. We can model this as
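One commonly used second-order form, in which the synaptic current is itself a leaky state with its own decay rate \(\alpha\) (an assumed symbol), is sketched below. The time indices and the reset term are assumptions that mirror the simplified neuron above:

\[ I_{syn}(t+1) = \alpha I_{syn}(t) + w x(t+1), \qquad U(t+1) = \beta U(t) + I_{syn}(t+1) - S(t) \bar{U} \]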
That is, the spike in the pre-synaptic neuron \(x(t)\) causes a gradual, as opposed to instantaneous, increase in the synaptic current \(I_{syn}\). This is passed to the post-synaptic neuron whose membrane potential \(U(t+1)\) is updated using the new synaptic current. In cases where long-term sparse signals need to be modelled this kind of synaptic model is better suited, compared to the standard Lapicque model. That being said, there is really no concrete evidence that this comparison holds strongly.
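A small simulation makes the smoothing effect visible. This sketch assumes a leaky synaptic current with decay `alpha` feeding a leaky membrane with decay `beta`; the function name and all parameter values are illustrative:

```python
def simulate_synaptic(x, w=0.5, alpha=0.8, beta=0.9, threshold=1.0):
    """Simulate a neuron with a leaky synaptic current.

    x: binary input spikes, one per time step.
    Returns (currents, potentials, spikes), one entry per step.
    """
    i_syn, u, s = 0.0, 0.0, 0
    currents, potentials, spikes = [], [], []
    for x_t in x:
        # The synaptic current decays with alpha and is driven by input spikes
        i_syn = alpha * i_syn + w * x_t
        # The membrane integrates the synaptic current; the spike from the
        # previous step triggers the reset by subtraction
        u = beta * u + i_syn - s * threshold
        s = 1 if u > threshold else 0
        currents.append(i_syn)
        potentials.append(u)
        spikes.append(s)
    return currents, potentials, spikes


# A single input spike at step 0: the membrane keeps rising for two more
# steps and the output spike appears with a delay.
currents, potentials, spikes = simulate_synaptic([1, 0, 0, 0, 0, 0])
```

With these values the membrane potential reaches the threshold two steps after the input spike, so the effect of one input is spread over several time steps rather than applied all at once.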
The second variation is based on the so-called alpha neurons, which model the membrane potential as a filter applied to the input spikes, \( U(t + 1) = \sum_i w_i (\epsilon \ast S_i)(t)\). Here \(w_i\) is the weight of the \(i\)-th input, \(\epsilon\) is a general filter, \(\ast\) is the convolution operation, and \(S_i(t)\) is the previous spike of that input. Depending on the filter, many different effects can be introduced, including delayed responses, threshold adaptation, and so on.
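As an illustration of the filtering view, the sketch below convolves input spike trains with an alpha-shaped kernel \(\epsilon(t) = (t/\tau) e^{1 - t/\tau}\). The kernel choice, the time constant `tau` and the function names are assumptions, and the potential is indexed at the same step as the convolution output rather than one step later:

```python
import math


def alpha_kernel(length, tau=2.0):
    """Samples of eps(t) = (t / tau) * exp(1 - t / tau) for t = 0..length-1.

    The kernel is zero at t = 0 and peaks at t = tau with value 1.
    """
    return [(t / tau) * math.exp(1.0 - t / tau) for t in range(length)]


def filtered_potential(spike_trains, weights, kernel):
    """Membrane potential as a weighted sum of filtered spike trains.

    Computes U(t) = sum_i w_i * sum_d eps(d) * S_i(t - d), a causal
    discrete convolution of every input spike train with the kernel.
    """
    steps = len(spike_trains[0])
    u = [0.0] * steps
    for w_i, s_i in zip(weights, spike_trains):
        for t in range(steps):
            for d, eps in enumerate(kernel):
                if d > t:
                    break  # causal: only current and past spikes contribute
                u[t] += w_i * eps * s_i[t - d]
    return u


# One input with a single spike at step 0 and weight 2.0: the response is
# zero at first and peaks two steps later, i.e. a delayed response.
u = filtered_potential([[1, 0, 0, 0]], [2.0], alpha_kernel(4))
```

Swapping in a different kernel changes the shape and timing of the response without touching the rest of the code.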
Overall, spiking neural networks provide a more realistic, biologically-inspired form of neural computation. Their main selling points are said to be the easy representation of binary spikes in hardware (compared to high precision floats), and their sparsity, which allows cheaper storage and faster computation. They are also said to be more energy-efficient because, like human sensors, they process the world in an asynchronous manner, only when a change in the sensor inputs is detected. This avoids a lot of processing for those sensors whose inputs don't change. Since spikes suppress any signals that are not strong enough, spiking neural nets are particularly well-suited to neuromorphic datasets where we have asynchronous events, along with their timestamps. The downside is that I don't think there has been any clear scenario where spiking neural nets yield strictly better performance than their standard artificial counterparts. It'll be interesting to see how this comparison evolves in the following years.