Neuron’s Eye View: Inferring Features of Complex Stimuli from Neural Responses
S1 Text. Mathematical details Derivation of ELBO and Inference
1 Evidence Lower Bound (ELBO)
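Here we derive the evidence lower bound (ELBO) used as a variational objective by our inference algorithm. That is, we want to calculate

$$\mathcal{L} \equiv \mathbb{E}_q\left[\log \frac{p(\Theta|N)}{q(\Theta)}\right] = \mathbb{E}_q[\log p(\Theta|N)] + H[q(\Theta)] \tag{1}$$

From [1], this can be written

$$\mathcal{L} = \mathbb{E}_{q(\pi)}\left[\log \frac{p(\pi)}{q(\pi)}\right] + \mathbb{E}_{q(A)}\left[\log \frac{p(A)}{q(A)}\right] + \mathbb{E}_q\left[\log \frac{p(N, z|\lambda, A, \pi)}{q(z)}\right] + \mathbb{E}_{q(\theta)}\left[\log \frac{p(\theta)}{q(\theta)}\right] + \mathbb{E}_{q(\lambda)}\left[\log \frac{p(\lambda)}{q(\lambda)}\right] + \mathbb{E}_{q(\gamma)}\left[\log \frac{p(\gamma)}{q(\gamma)}\right] \tag{2}$$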
For the first two terms, updates are standard and covered in [1]. The rest we do piece-by-piece below:
1.1 Log evidence
We would like to calculate $\mathbb{E}_q\left[\log \frac{p(N, z|x, \Theta)}{q(z)}\right]$. To do this, we make use of expectations calculated via the posteriors returned from the forward-backward algorithm:

$$\xi_t \equiv p(z_t|N, \theta), \qquad \Xi_{t,ij} \equiv p(z_{t+1} = j, z_t = i|N, \theta), \qquad \log Z_t = \log p(N_{t+1}|N_t, \Theta) \tag{3}$$
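Here, we have suppressed the latent feature index $k$ and abused notation by writing the observation index as $t$, but in the case of multiple observations at a given time, we pool across units and presentations: $N_t \equiv \sum_{m; t(m)=t} N_m$. From this, we can write

$$\mathbb{E}_q[\log p(N, z|x, \Theta)] = \sum_{mkr} N_m \left[\overline{\log \theta_m} + \overline{\log \lambda_{0u(m)}} + \xi_{t(m)k}\,\overline{\log \lambda_{zuk}} + x_{t(m)r}\,\overline{\log \lambda_{xu(m)r}}\right] - \sum_m \overline{\theta}_m\, \mathbb{E}_q\left[\Lambda_{t(m)u(m)}\right] + \sum_{tk}\left[\mathrm{tr}\left(\Xi_{tk}\,\overline{\log A_k^T}\right) + \xi_{0k}^T\,\overline{\log \pi_k}\right] + \text{constant} \tag{4}$$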
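where, for notational simplicity, we have replaced expectations with overlines: $\overline{x} \equiv \mathbb{E}[x]$. In what follows, we will drop the irrelevant constant. For $\overline{\log y}$, where $y \in \{\theta, \lambda_0, \lambda_z, \lambda_x\}$, the assumption $q(y) = \mathrm{Ga}(\alpha, \beta)$ gives

$$\overline{\log y} = \psi(\alpha) - \log \beta \tag{5}$$

with $\psi(x)$ the digamma function. Likewise, the expectation $\overline{\theta}$ is straightforward. For the expectation of the rate, we have

$$\mathbb{E}_q\left[\lambda_{0u} \prod_k (\lambda_{zuk})^{z_{tk}} \prod_r (\lambda_{xur})^{x_{tr}}\right] = \frac{\alpha_{0u}}{\beta_{0u}} \prod_k \left(1 - \xi_{tk} + \xi_{tk}\frac{\alpha_{zuk}}{\beta_{zuk}}\right) \prod_r \frac{1}{\beta_{xur}^{x_{tr}}} \frac{\Gamma(\alpha_{xur} + x_{tr})}{\Gamma(\alpha_{xur})} \tag{6}$$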
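However, for $\alpha \gg x$, we have $\Gamma(\alpha + x)/\Gamma(\alpha) \approx \alpha^x$, so that we can write

$$\mathbb{E}_q[\Lambda_{tu}] = H_{0u} F_{tu} G_{tu} \tag{7}$$

with $G_{tu} \approx \prod_r (\alpha_{xur}/\beta_{xur})^{x_{tr}}$, where $H_{0u}$ and $F_{tu}$ denote the first two factors on the right-hand side of Eq (6). In addition, it will later be useful to have the expectation over all except a particular feature $k$ or $r$, for which we define

$$F_{tuk} \equiv \prod_{k' \neq k} \left(1 - \xi_{tk'} + \xi_{tk'}\frac{\alpha_{zuk'}}{\beta_{zuk'}}\right) \tag{8}$$

$$G_{tur} \equiv \prod_{r' \neq r} \left(\frac{\alpha_{xur'}}{\beta_{xur'}}\right)^{x_{tr'}} \tag{9}$$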
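The quality of the approximation $\Gamma(\alpha + x)/\Gamma(\alpha) \approx \alpha^x$ can be checked numerically. The following is a minimal Python sketch, not taken from the authors' implementation: the function names are hypothetical, and a Gamma distribution parameterized by shape $\alpha$ and rate $\beta$ is assumed, consistent with Eq (5).

```python
import math


def gamma_power_moment(alpha, beta, x):
    """Exact E[lambda**x] for lambda ~ Ga(alpha, beta) (shape, rate).

    Uses E[lambda**x] = Gamma(alpha + x) / (Gamma(alpha) * beta**x),
    evaluated in log space for numerical stability.
    """
    log_moment = math.lgamma(alpha + x) - math.lgamma(alpha) - x * math.log(beta)
    return math.exp(log_moment)


def gamma_power_moment_approx(alpha, beta, x):
    """Approximation (alpha / beta)**x, reasonable when alpha >> x."""
    return (alpha / beta) ** x


if __name__ == "__main__":
    for alpha in (2.0, 20.0, 500.0):
        exact = gamma_power_moment(alpha, alpha / 2.0, 2.0)
        approx = gamma_power_moment_approx(alpha, alpha / 2.0, 2.0)
        print(alpha, exact, approx, abs(exact - approx) / exact)
```

For an integer exponent such as $x = 2$, the exact moment is $\alpha(\alpha + 1)/\beta^2$, so the relative error of the approximation is $1/(\alpha + 1)$ and shrinks as $\alpha$ grows.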
Finally, we want the entropy of the variational posterior over $z$, $\mathbb{E}_q[-\log q(z)]$. We can write this in the form

$$-\sum_{tk}\left[\xi_{tk}^T \eta_{tk} + \mathrm{tr}\left(\Xi_{tk} \tilde{A}_k^T\right) + \xi_{0k}^T \tilde{\pi}_k - \log Z_{tk}\right] \tag{10}$$

with $(\eta, \tilde{A}, \tilde{\pi})$ the parameters of the variational posterior corresponding to the emission, transition, and initial state probabilities of the Markov chain (interpreted as matrices) and $Z$ the normalization constant. From [1], we have that variational updates should give

$$\tilde{A}_k \leftarrow \overline{\log A_k} \tag{11}$$

$$\tilde{\pi}_k \leftarrow \overline{\log \pi_k} \tag{12}$$

while the effective emission probabilities in the "on" ($z = 1$) state of the HMM are

$$\eta_{tk} \leftarrow \delta_{z_{tk},1} \sum_{m; t(m)=t} N_m\, \overline{\log \lambda_{zu(m)k}} - \sum_{m; t(m)=t} \overline{\theta}_m H_{0u(m)} F_{tu(m)k} G_{tu(m)} \tag{13}$$

Given these update rules, we can then alternate between calculating $(\eta, \tilde{A}, \tilde{\pi})$, performing forward-backward to get $(\xi, \Xi, \log Z)$, and recalculating $(\eta, \tilde{A}, \tilde{\pi})$.
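The forward-backward step in this alternation can be illustrated with the standard scaled recursion for a discrete-state HMM. This is a minimal sketch under stated assumptions rather than the authors' code: the function name and argument layout are hypothetical, `log_emit[t][j]` plays the role of the effective log emission weight $\eta$, and `A[i][j]` is taken to be the (effective) weight of a transition from state $i$ to state $j$.

```python
import math


def forward_backward(log_emit, A, pi):
    """Scaled forward-backward for a discrete-state HMM.

    log_emit[t][j]: log emission weight of state j at time t
    A[i][j]:        weight of the transition i -> j
    pi[j]:          initial weight of state j
    Returns (xi, Xi, logZ) with xi[t][j] = p(z_t = j | data) and
    Xi[t][i][j] = p(z_t = i, z_{t+1} = j | data).
    """
    T, S = len(log_emit), len(pi)
    emit = [[math.exp(v) for v in row] for row in log_emit]

    # Forward pass: alpha[t] is the filtered distribution, scale[t] its normalizer.
    alpha, scale = [], []
    for t in range(T):
        if t == 0:
            pred = list(pi)
        else:
            pred = [sum(alpha[t - 1][i] * A[i][j] for i in range(S)) for j in range(S)]
        unnorm = [pred[j] * emit[t][j] for j in range(S)]
        c = sum(unnorm)
        alpha.append([v / c for v in unnorm])
        scale.append(c)

    # Backward pass, scaled by the same normalizers.
    beta = [[1.0] * S for _ in range(T)]
    for t in range(T - 2, -1, -1):
        beta[t] = [
            sum(A[i][j] * emit[t + 1][j] * beta[t + 1][j] for j in range(S)) / scale[t + 1]
            for i in range(S)
        ]

    # Posterior marginals, pairwise marginals, and log normalizer.
    xi = [[alpha[t][j] * beta[t][j] for j in range(S)] for t in range(T)]
    Xi = [
        [
            [alpha[t][i] * A[i][j] * emit[t + 1][j] * beta[t + 1][j] / scale[t + 1]
             for j in range(S)]
            for i in range(S)
        ]
        for t in range(T - 1)
    ]
    logZ = sum(math.log(c) for c in scale)
    return xi, Xi, logZ
```

In the variational setting the emission and transition weights need not be normalized probabilities, which is why the sketch only assumes non-negative weights and returns the log normalizer explicitly.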
1.2 Overdispersion, firing rate effects
The cases of $p(\theta)$ and $p(\lambda)$ are both straightforward. If we ignore subscripts and write $p(y) = \mathrm{Ga}(a, b)$, $q(y) = \mathrm{Ga}(\alpha, \beta)$, then

$$\mathbb{E}_q\left[\log \frac{p(y)}{q(y)}\right] = (\overline{a} - 1)\,\overline{\log y} + \overline{b}\,\overline{y} + H[q(y)] \tag{14}$$

where again, $\overline{\log y}$, $\overline{y}$, and $H[q(y)]$ are straightforward properties of the Gamma distribution. Expectations of the prior parameters are listed in Table 1. Note in the last line that there is no expectation, since we have not assumed a hierarchy over firing rate effects for the covariates, $x$.
1.3 Hyperparameters
As shown in the main text, the hyperparameters $c$ and $d$, given gamma priors, have conjugate gamma posteriors, so that their contribution to the evidence lower bound, $\mathbb{E}_q\left[\log \frac{p(\gamma)}{q(\gamma)}\right]$, is a sum of terms of the form

$$(a_c - 1)\,\overline{\log c} + b_c\,\overline{c} + H[q(c)] + (a_d - 1)\,\overline{\log d} + b_d\,\overline{d} + H[q(d)] \tag{15}$$
Table 1. Expectations of prior parameters for overdispersion and firing rates

Variable | $\mathbb{E}_q[a]$ | $\mathbb{E}_q[b]$
$\theta$ | $\overline{s}$ | $\overline{s}$
$\lambda_0$ | $\overline{c_0}$ | $\overline{c_0 d_0}$
$\lambda_z$ | $\overline{c_z}$ | $\overline{c_z d_z}$
$\lambda_x$ | $a_x$ | $b_x$
In other words, these are straightforward gamma expectations, functions of the prior parameters a and b for each variable and the corresponding posterior parameters α and β. Similarly, the overdispersion terms are exactly the same with the substitutions c → s, d → 1. As we will see below, the expectations under the variational posterior of s, c, and d are themselves straightforward to calculate.
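The Gamma expectations referred to here ($\overline{y}$, $\overline{\log y}$, and $H[q(y)]$) have closed forms. The sketch below is illustrative only and assumes the shape-rate parameterization $q(y) = \mathrm{Ga}(\alpha, \beta)$; the function names are hypothetical, and the digamma function is approximated with a recurrence plus asymptotic series so that only the standard library is needed.

```python
import math


def digamma(x):
    """Digamma function for x > 0 via upward recurrence and an asymptotic series."""
    result = 0.0
    while x < 10.0:
        # psi(x) = psi(x + 1) - 1 / x
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 / x - series


def gamma_stats(alpha, beta):
    """Return (E[y], E[log y], entropy) for y ~ Ga(alpha, beta) (shape, rate)."""
    mean = alpha / beta
    mean_log = digamma(alpha) - math.log(beta)
    entropy = alpha - math.log(beta) + math.lgamma(alpha) + (1.0 - alpha) * digamma(alpha)
    return mean, mean_log, entropy


if __name__ == "__main__":
    print(gamma_stats(3.0, 2.0))
```

The entropy expression is the same $H_g(\alpha, \beta)$ that reappears in the Normal-Gamma entropy of Eq (43).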
1.4 Autocorrelated noise
In the main text, we assumed the $\theta_m$ to be uncorrelated. However, it is possible to model temporal autocorrelation among the $\theta_m$ when observations correspond to the same neuron at successive time points. More specifically, let us replace the observation index $m$ by $\tau$, the experimental clock time. We then write a particular observation as corresponding to a stimulus time $t(\tau)$ and a set of units $u(\tau)$, which, for simplicity, we will assume fixed and simply write as $u$ (see footnote 1). To model the autocorrelation of noise across successive times, we then write

$$N_m \sim \mathrm{Pois}(\Lambda_{t(\tau),u}\,\theta_{\tau u}), \qquad \text{assuming } \theta_{\tau u} = \phi_{\tau u}\,\theta_{\tau-1,u} \tag{16}$$
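If $\phi_{0u} = \theta_{0u}$, we then have $\theta_{\tau u} = \prod_{\tau' \leq \tau} \phi_{\tau' u}$. In essence, this is a log-autoregressive process in which the innovations are not necessarily normally distributed. In fact, if we further assume that $p(\phi_{\tau u}) = \mathrm{Ga}(s_u, r_u)$ and $q(\phi_{\tau u}) = \mathrm{Ga}(\omega_{\tau u}, \zeta_{\tau u})$, we can once again make use of conjugate updates. Given these assumptions, we need to make the following modifications for the third and fifth terms in the evidence lower bound of Eq (4):

$$\sum_{mkr} N_m \xi_{t(m)k}\,\overline{\log \lambda_{zuk}} \rightarrow \sum_{mkr} N_m \xi_{t(m)k}\,\overline{\log \lambda_{zuk}} + \sum_{\tau u} \overline{\log \phi_{\tau u}} \sum_{\tau' \geq \tau} N_{\tau' u} \tag{17}$$

$$-\sum_m \overline{\theta}_m\, \mathbb{E}_q\left[\Lambda_{t(m)u(m)}\right] \rightarrow -\sum_{\tau u} H_{0u} F_{\tau u k} G_{\tau u} \prod_{\tau' \leq \tau} \frac{\omega_{\tau' u}}{\zeta_{\tau' u}} \tag{18}$$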
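Thus the effective emission probabilities in State 1 now become:

$$\eta_{tk} \rightarrow \delta_{z_{tk},1} \sum_m N_m\, \overline{\log \lambda_{zuk}} - \sum_{u; t(\tau)=t} H_{0u} F_{\tau u k} G_{\tau u} \prod_{\tau' \leq \tau} \frac{\omega_{\tau' u}}{\zeta_{\tau' u}} \tag{19}$$

1.5 Latent states, semi-Markov dynamics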
We model each of the latent states $z_{tk}$ as an independent Markov process for each feature $k$. That is, each $k$ indexes an independent Markov chain with initial state probability $\pi_k \sim \mathrm{Dir}(\alpha_\pi)$ and transition matrix $A_k \sim \mathrm{Dir}(\alpha_A)$. For the semi-Markov case, we assume that the dwell times in each state are distributed independently for each chain according to an integer-valued, truncated lognormal distribution with support on the integers $1 \ldots D$:

$$p_k(d|z = j) = \text{Log-Normal}(d|m_{jk}, s_{jk}^2)/W_{jk} \tag{20}$$

$$W_{jk} = \sum_{d=1}^{D} \text{Log-Normal}(d|m_{jk}, s_{jk}^2) \tag{21}$$

Footnote 1. The generalization to partially overlapping neurons and stimuli is straightforward but complex and notationally cumbersome.
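Eqs (20) and (21) amount to evaluating the continuous lognormal density at the integers $1 \ldots D$ and renormalizing. The following is a minimal sketch with hypothetical function names, assuming $m$ and $s^2$ are the mean and variance of $\log d$:

```python
import math


def lognormal_pdf(d, m, s2):
    """Continuous lognormal density at d > 0, with log d ~ N(m, s2)."""
    z = (math.log(d) - m) ** 2 / (2.0 * s2)
    return math.exp(-z) / (d * math.sqrt(2.0 * math.pi * s2))


def truncated_dwell_pmf(m, s2, D):
    """Integer-valued, truncated lognormal pmf on d = 1..D.

    Returns (pmf, W), where pmf[d - 1] = p(d) and W is the normalizer of Eq (21).
    """
    weights = [lognormal_pdf(d, m, s2) for d in range(1, D + 1)]
    W = sum(weights)
    return [w / W for w in weights], W


if __name__ == "__main__":
    pmf, W = truncated_dwell_pmf(math.log(5.0), 0.25, 20)
    print(W, pmf[:6])
```

Because $W$ depends on $(m, s^2)$, it cannot be dropped as a constant during inference, which is the source of the extra normalization term discussed next.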
Note that we have allowed the dwell time distribution to depend on both the feature $k$ and the state of the Markov chain $j$. In addition, we put independent Normal-Gamma priors on the mean ($m_{jk}$) and precision ($\tau_{jk} \equiv s_{jk}^{-2}$) parameters of the distribution: $(m, \tau) \sim \mathrm{NG}(\mu, \lambda, \alpha, \beta)$. In this case, we additionally need to perform inference on the parameters $(m, \tau)$ of the dwell time distributions for each state. In the case of continuous dwell times, our model in Equation 20 would have $W = 1$ and be conjugate to the Normal-Gamma prior on $(m, \tau)$, but the restriction to discrete dwell times requires us to again lower bound the variational objective:

$$\mathbb{E}_q[-\log W_{jk}] = \mathbb{E}_q\left[-\log\left(\sum_{d=1}^{D} p(d|j)\right)\right] \geq -\log \sum_{d=1}^{D} \mathbb{E}_q[p(d|j)] \tag{22}$$
This correction for truncation must then be added to $\mathbb{E}_q[\log p(z|\Theta)]$. For inference in the semi-Markov case, we use an extension of the forward-backward algorithm [2], as well as the algorithms in [3, 4], at the expense of computational complexity $O(SDT)$ ($S = 2$) per latent state, to calculate $q(z_k)$. For the $4SK$ hyperparameters of the Normal-Gamma distribution, we perform an explicit BFGS optimization on the $4S$ parameters of each chain on each iteration (detailed in Subsection 2.3). The new ELBO for the semi-Markov model thus adds to the terms involving $\Xi$ and $\xi_0$ in (4) an additional piece:

$$\sum_{djk} C(d, j, k)\, \mathbb{E}_q\left[\log p_k(d|j) - \log\left(\sum_{d'=1}^{D} p_k(d'|j)\right)\right] \geq \sum_{djk} C(d, j, k)\left[\mathbb{E}_q[\log p_k(d|j)] - \log \sum_{d'=1}^{D} \mathbb{E}_q[p_k(d'|j)]\right] \tag{23}$$
where, as noted above, the second term inside the expectation arises from the need to normalize $p(d|j)$ over discrete times and the inequality follows from Jensen's inequality. Here, $d$ is the duration variable, $j$ labels states of the chain (here 0 and 1), and $k$, as elsewhere, labels chains [2–4]. $C$ is defined for each state and chain as the probability of a dwell time $d$ in state $j$ conditioned on the event of just having transitioned to $j$ (see footnotes 2 and 3). Thus the objective to be optimized is

$$\sum_{djk} C(d, j, k)\left[\mathbb{E}_q[\log p_k(d|j)] - \log \sum_{d'=1}^{D} \mathbb{E}_q[p_k(d'|j)]\right] + \mathbb{E}_q\left[\log \frac{p(m, \tau)}{q(m, \tau)}\right] \tag{25}$$

We focus on each of these and their updates below.

Footnote 2. Thus $C$ is equivalent to $D_{t|T}$ in [2].

Footnote 3. We also note that, just as the emission, transition, and initial state probabilities have their counterparts in variational parameters $(\eta, \tilde{A}, \tilde{\pi})$ in $q(z)$, so $C$ is matched by a term $-\sum_{djk} C(d, j, k)\,\nu_{djk}$ in $\mathbb{E}_q[-\log q]$. And in analogy with the HMM case, variation with respect to $\nu$ gives

$$\nu_{djk} = \mathbb{E}_q[\log p_k(d|j)] - \log \sum_{d'=1}^{D} \mathbb{E}_q[p_k(d'|j)], \tag{24}$$

implying that there are cancellations between $\log p(N, z|\Theta)$ and $\log q(z)$ in $\mathcal{L}$ and care must be taken when calculating it [1].

1.5.1 Contributions to the ELBO
For the case of $p(d|m, s^2)$ Log-Normal and $(m, \tau = s^{-2})$ Normal-Gamma, the terms $\mathbb{E}_q[\log p(d|j)]$ involve only routine expectations of natural parameters of the Normal-Gamma, and similarly for the expected log prior and entropy in the last term. Writing the delay distribution in the exponential family form with base measure $h$ and natural parameter $\eta$

$$p(d|z = i) = h(d) \exp(\eta_i \cdot T(d) - A(\eta_i)) \tag{26}$$

yields for the lognormal distribution $\log d \sim \mathcal{N}(m, s^2)$

$$h(d) = \frac{1}{\sqrt{2\pi}\, d} \tag{27}$$

$$T(d) = \begin{pmatrix} \log d \\ (\log d)^2 \end{pmatrix} \tag{28}$$

$$\eta = \begin{pmatrix} \frac{m}{s^2} \\ -\frac{1}{2s^2} \end{pmatrix} \tag{29}$$

$$A(\eta) = \frac{m^2}{2s^2} + \log s \tag{30}$$
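If we then put a Normal-Gamma posterior on $(m, \tau)$ with parameters $(\mu, \lambda, \alpha, \beta)$, we can easily calculate

$$\mathbb{E}_q[\eta] = \mathbb{E}_q\begin{pmatrix} m\tau \\ -\frac{1}{2}\tau \end{pmatrix} = \begin{pmatrix} \mu\frac{\alpha}{\beta} \\ -\frac{\alpha}{2\beta} \end{pmatrix} \tag{31}$$

$$\mathbb{E}_q[A(\eta)] = \mathbb{E}_q\left[\frac{1}{2} m^2 \tau - \frac{1}{2}\log \tau\right] \tag{32}$$

$$= \mathbb{E}_\tau\left[\frac{1}{2}\left(\mu^2 + \frac{1}{\lambda\tau}\right)\tau\right] - \frac{1}{2}(\psi(\alpha) - \log \beta) \tag{33}$$

$$= \frac{1}{2}\left[\frac{\alpha}{\beta}\mu^2 + \frac{1}{\lambda}\right] - \frac{1}{2}(\psi(\alpha) - \log \beta) \tag{34}$$

The relevant piece of the evidence lower bound for updating $(\mu, \lambda, \alpha, \beta)$ is

$$\sum_{i,d} C(d, i)\left[\log h(d) + \mathbb{E}_q[\eta_i] \cdot T(d) - \mathbb{E}_q[A(\eta_i)]\right] + \mathbb{E}_q\left[\log \frac{p(m, \tau)}{q(m, \tau)}\right] \tag{35}$$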
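again, with $q(m, \tau) = \text{Normal-Gamma}(\mu, \lambda, \alpha, \beta)$. We can thus write

$$\mathbb{E}_q[\log p(d|j)] = -\frac{1}{2}\log 2\pi\, C_0 - C_1 + \mathbb{E}_q[m\tau]\, C_1 - \frac{1}{2}\mathbb{E}_q[\tau]\, C_2 - \frac{1}{2} C_0\, \mathbb{E}_q[m^2\tau] + \frac{1}{2} C_0\, \mathbb{E}_q[\log \tau] = G + F \cdot T(\alpha, \beta, \lambda, \mu) \tag{36}$$

where

$$G \equiv -\frac{1}{2}\log 2\pi\, C_0 - C_1 \tag{37}$$

$$F \cdot T = \begin{pmatrix} \frac{C_0}{2} \\ -\frac{C_2}{2} \\ C_1 \\ -\frac{C_0}{2} \end{pmatrix} \cdot \mathbb{E}_q\begin{pmatrix} \log \tau \\ \tau \\ m\tau \\ m^2\tau \end{pmatrix} \tag{38}$$

$$C_0 \equiv \sum_{id} C_{id} \tag{39}$$

$$C_1 \equiv \sum_{id} C_{id} \log d \tag{40}$$

$$C_2 \equiv \sum_{id} C_{id} (\log d)^2 \tag{41}$$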
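The decomposition $G + F \cdot T$ can be evaluated directly from the duration weights and the Normal-Gamma parameters. The sketch below is an illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, the weights are given for a single state as a list `C` with `C[d - 1]` the weight of duration `d`, and the Normal-Gamma moments used are $\mathbb{E}[\tau] = \alpha/\beta$, $\mathbb{E}[\log \tau] = \psi(\alpha) - \log \beta$, $\mathbb{E}[m\tau] = \mu\alpha/\beta$, and $\mathbb{E}[m^2\tau] = \mu^2\alpha/\beta + 1/\lambda$.

```python
import math


def digamma(x):
    """Digamma function for x > 0 via upward recurrence and an asymptotic series."""
    result = 0.0
    while x < 10.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1.0 / 12.0 - inv2 * (1.0 / 120.0 - inv2 / 252.0))
    return result + math.log(x) - 0.5 / x - series


def ng_expectations(mu, lam, alpha, beta):
    """E_q of (log tau, tau, m * tau, m**2 * tau) under Normal-Gamma(mu, lam, alpha, beta)."""
    e_tau = alpha / beta
    e_log_tau = digamma(alpha) - math.log(beta)
    return [e_log_tau, e_tau, mu * e_tau, mu * mu * e_tau + 1.0 / lam]


def weighted_loglik(C, mu, lam, alpha, beta):
    """Evaluate G + F . T for duration weights C, where C[d - 1] weights duration d."""
    C0 = sum(C)
    C1 = sum(c * math.log(d) for d, c in enumerate(C, start=1))
    C2 = sum(c * math.log(d) ** 2 for d, c in enumerate(C, start=1))
    G = -0.5 * math.log(2.0 * math.pi) * C0 - C1
    F = [C0 / 2.0, -C2 / 2.0, C1, -C0 / 2.0]
    T = ng_expectations(mu, lam, alpha, beta)
    return G + sum(f * t for f, t in zip(F, T))


if __name__ == "__main__":
    print(weighted_loglik([0.5, 1.0, 2.0, 0.25], 0.8, 2.0, 3.0, 1.5))
```

Since the data enter only through $C_0$, $C_1$, and $C_2$, these three sums can be computed once per iteration and reused across evaluations of the objective.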
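For calculating the final term in (25), we can make use of conjugacy (since $p(m, \tau|\alpha_0, \beta_0, \lambda_0, \mu_0)$ is itself Normal-Gamma) and the fact that (36) only depends linearly on the expected sufficient statistics of $q$ to write the contribution of $\mathbb{E}_q[\log p(m, \tau)]$ as an update to $F$:

$$F' = \begin{pmatrix} \frac{C_0}{2} + \alpha_0 \\ -\frac{C_2}{2} - \beta_0 - \frac{\lambda_0 \mu_0^2}{2} \\ C_1 + \lambda_0 \mu_0 \\ -\frac{C_0}{2} - \frac{\lambda_0}{2} \end{pmatrix} \tag{42}$$

Moreover, we can make use of the standard entropy formula for the Normal-Gamma distribution:

$$\begin{aligned} H[q] &= \mathbb{E}_q[-\log q] = \mathbb{E}_q[-\log q(\tau) - \log q(m|\tau)] \\ &= H_g(\alpha, \beta) + \mathbb{E}_q\left[-\log\sqrt{\frac{\lambda\tau}{2\pi}} + \frac{\lambda\tau}{2}(m - \mu)^2\right] \\ &= \alpha - \log \beta + \log \Gamma(\alpha) + (1 - \alpha)\psi(\alpha) - \frac{1}{2}\log\frac{\lambda}{2\pi} - \frac{1}{2}(\psi(\alpha) - \log \beta) + \frac{1}{2} \end{aligned} \tag{43}$$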
Only slightly more complicated is the term arising from the normalization constant, which in the standard $(\mu, \lambda, \alpha, \beta)$ parameterization of the Normal-Gamma takes the form

$$\mathbb{E}_q[p(d|m, \tau)] = \int d\tau\, dm\; p(d|i, m, \tau)\, q(m, \tau) = \frac{1}{\sqrt{2\pi}\, d} \frac{\Gamma(\alpha + 1/2)}{\Gamma(\alpha)} \sqrt{\frac{\lambda}{1 + \lambda}}\; \frac{\beta^{-\frac{1}{2}}}{\hat{\beta}^{\alpha + \frac{1}{2}}} \tag{44}$$

with

$$\hat{\beta} \equiv 1 + \frac{1}{2\beta}\frac{\lambda}{1 + \lambda}(\log d - \mu)^2 \tag{45}$$

This can be derived either by performing the integral directly or by noting that the result should be proportional to the posterior (Normal-Gamma, by conjugacy) of $(m, \tau)$ after having observed a single data point, $\log d$. Taken together, (36) (with $F \rightarrow F'$), (43), and (44) allow (25) to be calculated directly as a function of the variational parameters $(\alpha, \beta, \lambda, \mu)$ of $q(m, \tau)$. Were it not for the normalization constant arising from the restriction to discrete $d$, the updates resulting from optimizing this objective would reduce by conjugacy to simple parameter updates. With that additional term, we are forced to use a brute-force BFGS optimization step instead (see 2.3 below).
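Eqs (44) and (45) describe a Student-t density in $\log d$, which makes the truncation term inexpensive to evaluate. The following sketch (hypothetical function names, standard library only) computes $\mathbb{E}_q[p(d|m, \tau)]$ in log space and the resulting bound $-\log \sum_{d=1}^{D} \mathbb{E}_q[p(d|j)]$ from Eq (22):

```python
import math


def expected_lognormal_density(d, mu, lam, alpha, beta):
    """E_q[p(d | m, tau)] for q(m, tau) = Normal-Gamma(mu, lam, alpha, beta), Eqs (44)-(45)."""
    x = math.log(d)
    beta_hat = 1.0 + lam / (1.0 + lam) * (x - mu) ** 2 / (2.0 * beta)
    log_val = (
        -0.5 * math.log(2.0 * math.pi) - x
        + math.lgamma(alpha + 0.5) - math.lgamma(alpha)
        + 0.5 * math.log(lam / (1.0 + lam))
        - 0.5 * math.log(beta)
        - (alpha + 0.5) * math.log(beta_hat)
    )
    return math.exp(log_val)


def truncation_bound(mu, lam, alpha, beta, D):
    """Jensen lower bound on E_q[-log W]: -log sum_{d=1}^{D} E_q[p(d | m, tau)]."""
    total = sum(expected_lognormal_density(d, mu, lam, alpha, beta) for d in range(1, D + 1))
    return -math.log(total)


if __name__ == "__main__":
    print(truncation_bound(1.0, 1.5, 3.0, 2.0, 50))
```

As a sanity check, the expected density integrates to one over continuous $d > 0$, which can be confirmed by numerical integration in the variable $\log d$.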
1.5.2 Forward-Backward Algorithm for HSMM: Notation
Here, we follow Yu and Kobayashi [2], mutatis mutandis. Define:

$\psi_t(m) = p(y_t|z_t = m)$ (observation probability)

$\alpha_{t|x}(m, d) = p(z_t = m, d_t = d|y_{1:x})$ (condition on data)

$\gamma_{t|x}(m) = \sum_d \alpha_{t|x}(m, d)$ (marginalize out $d$)

$\psi_t^*(m) = \dfrac{\alpha_{t|t}(m, d)}{\alpha_{t|t-1}(m, d)} = \dfrac{\psi_t(m)}{p(y_t|y_{1:t-1})}$

$r_t^{-1} = p(y_t|y_{1:t-1}) = \sum_{m,d} \alpha_{t|t-1}(m, d)\,\psi_t(m) = \sum_m \gamma_{t|t-1}(m)\,\psi_t(m)$

$Z = p(y_{1:T}) = \left(\prod_{t=1}^{T} r_t\right)^{-1}$ (probability of data)

$D_{t|x}(m, d) = p(z_t = m, d_t = d, d_{t-1} = 1|y_{1:x})$ (entering $(m, d)$)

$T_{t|x}(n, m) = p(z_t = n, z_{t-1} = m, d_{t-1} = 1|y_{1:x})$ ($m \rightarrow n$)

$E_t(m) = p(z_t = m, d_t = 1|y_{1:t}) = \alpha_{t|t-1}(m, 1)\,\psi_t^*(m)$ ($m$ ends at $t$)

$S_t(m) = p(d_t = 1, z_{t+1} = m|y_{1:t}) = \sum_n A(m, n) E_t(n)$ ($m$ starts at $t + 1$)

and for the backward pieces

$\beta_t(m, d) = \dfrac{p(y_{t:T}|z_t = m, d_t = d)}{p(y_{t:T}|y_{1:t-1})} = \dfrac{p(z_t = m, d_t = d|y_{1:T})}{p(z_t = m, d_t = d|y_{1:t-1})}$

$E_t^*(m) = \dfrac{p(y_{t:T}|z_t = m, d_{t-1} = 1)}{p(y_{t:T}|y_{1:t-1})} = \sum_d p_m(d)\,\beta_t(m, d)$

$S_t^*(m) = \dfrac{p(y_{t:T}|z_{t-1} = m, d_{t-1} = 1)}{p(y_{t:T}|y_{1:t-1})} = \sum_n E_t^*(n) A(n, m)$

These quantities then allow us to write the posterior probabilities

$\alpha_{t|T}(m, d) = \beta_t(m, d)\,\alpha_{t|t-1}(m, d)$ (posterior probability of $(m, d)$)

$T_{t|T}(n, m) = E_t^*(n) A(n, m) E_{t-1}(m)$ (posterior probability of $m \rightarrow n$)

$D_{t|T}(m, d) = \beta_t(m, d)\,p_m(d)\,S_{t-1}(m)$ (posterior probability of entering $(m, d)$ at $t$)
1.5.3 Forward
From Yu and Kobayashi, we have the following forward pass:

1. $\alpha_{1|0}(m, d) = \pi_m p_m(d)$ for time 1; otherwise, use

$$\alpha_{t|t-1}(m, d) = S_{t-1}(m)\,p_m(d) + \psi_{t-1}^*(m)\,\alpha_{t-1|t-2}(m, d + 1) \tag{46}$$

2. Use $\alpha_{t|t-1}$ and $\psi_t$ to get $r_t$; together, these give $\psi_t^*$. Store $r_t$.
3. Calculate and save $E_t$.
4. Calculate $S_t$.
5. Loop.
1.5.4 Backward
Again following Yu and Kobayashi:

1. Initialize $\beta_T(m, d) = \psi_T^*(m) = \psi_T(m)\,r_T$.
2. $\psi_t^*(m) = \psi_t(m)\,r_t$.
3. Calculate $\beta_t(m, d)$ by

$$\beta_t(m, d) = \begin{cases} S_{t+1}^*(m)\,\psi_t^*(m), & d = 1 \\ \beta_{t+1}(m, d - 1)\,\psi_t^*(m), & d > 1 \end{cases} \tag{47}$$

where, again,

$$S_{t+1}^*(m) = \sum_n E_{t+1}^*(n) A(n, m) \tag{48}$$

$$E_{t+1}^*(n) = \sum_d \beta_{t+1}(n, d)\,p_n(d) \tag{49}$$

4. Use the above equations to calculate $E_t^*$ and $S_t^*$. Save $E^*$.

1.5.5 Sufficient Statistics
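Finally, we can calculate and return:

$$\log Z = -\sum_{t=1}^{T} \log r_t \tag{50}$$

$$\xi_t(m) = \sum_d \alpha_{t|T}(m, d) = \gamma_{t|T}(m) = \sum_d \alpha_{t|t-1}(m, d)\,\beta_t(m, d) \tag{51}$$

$$\Xi_{t+1,t}(n, m) = T_{t+1|T}(n, m) = E_{t+1}^*(n) A(n, m) E_t(m) \tag{52}$$

$$C_t(m, d) = D_{t|T}(m, d) = \begin{cases} \pi_m\,p_m(d)\,\beta_1(m, d), & t = 1 \\ S_{t-1}(m)\,p_m(d)\,\beta_t(m, d), & t > 1 \end{cases} \tag{53}$$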
Note that, because Ξt+1,t is the joint absolute probability of the transition (i.e., not conditioned on dt = 1) it is not normalized when summing over m and n. The same is true for C when summing over m and d.
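The forward pass, backward pass, and sufficient statistics above can be collected into one routine. The following is a minimal sketch under stated assumptions and should not be read as the authors' implementation: the function name and argument layout are hypothetical, durations are indexed so that `p[m][d - 1]` is $p_m(d)$, `A[n][m]` is the probability of a transition $m \rightarrow n$ as in the notation above, and a segment is allowed to extend past the final time $T$.

```python
import math


def hsmm_forward_backward(psi, A, pi, p):
    """Explicit-duration forward-backward in the notation of Sections 1.5.2-1.5.5.

    psi[t][m]: observation weight of state m at time t
    A[n][m]:   probability of a transition m -> n
    pi[m]:     initial state probability
    p[m][d-1]: duration pmf of state m for d = 1..D
    Returns (xi, C, logZ) as in Eqs (50), (51), and (53).
    """
    T, S, D = len(psi), len(pi), len(p[0])
    alpha = [[[0.0] * D for _ in range(S)] for _ in range(T)]  # alpha_{t|t-1}(m, d)
    psi_star = [[0.0] * S for _ in range(T)]
    r = [0.0] * T
    E = [[0.0] * S for _ in range(T)]
    S_fwd = [[0.0] * S for _ in range(T)]

    # Forward pass, Eq (46).
    for t in range(T):
        for m in range(S):
            for d in range(D):
                if t == 0:
                    alpha[t][m][d] = pi[m] * p[m][d]
                else:
                    carry = alpha[t - 1][m][d + 1] if d + 1 < D else 0.0
                    alpha[t][m][d] = S_fwd[t - 1][m] * p[m][d] + psi_star[t - 1][m] * carry
        r[t] = 1.0 / sum(sum(alpha[t][m]) * psi[t][m] for m in range(S))
        for m in range(S):
            psi_star[t][m] = psi[t][m] * r[t]
            E[t][m] = alpha[t][m][0] * psi_star[t][m]
        for m in range(S):
            S_fwd[t][m] = sum(A[m][n] * E[t][n] for n in range(S))

    # Backward pass, Eqs (47)-(49).
    beta = [[[0.0] * D for _ in range(S)] for _ in range(T)]
    for t in range(T - 1, -1, -1):
        if t < T - 1:
            E_star = [sum(p[n][d] * beta[t + 1][n][d] for d in range(D)) for n in range(S)]
            S_star = [sum(E_star[n] * A[n][m] for n in range(S)) for m in range(S)]
        for m in range(S):
            for d in range(D):
                if t == T - 1:
                    beta[t][m][d] = psi_star[t][m]
                elif d == 0:
                    beta[t][m][d] = S_star[m] * psi_star[t][m]
                else:
                    beta[t][m][d] = beta[t + 1][m][d - 1] * psi_star[t][m]

    # Sufficient statistics, Eqs (50), (51), and (53).
    logZ = -sum(math.log(v) for v in r)
    xi = [[sum(alpha[t][m][d] * beta[t][m][d] for d in range(D)) for m in range(S)]
          for t in range(T)]
    C = [[[(pi[m] if t == 0 else S_fwd[t - 1][m]) * p[m][d] * beta[t][m][d]
           for d in range(D)] for m in range(S)] for t in range(T)]
    return xi, C, logZ
```

The nested loops make the $O(SDT)$ cost of each pass explicit. A vectorized implementation would follow the same recursions.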
2 Inference

2.1 Conjugate updates
For updates on the overdispersion, firing rate, and hyperparameter variables, we have the simple conjugate update rules depicted in Table 2. These can be derived either from free-form variational arguments or from exponential family rules, but are in any case straightforward [5]. If we assume a $\mathrm{Ga}(a, b)$ prior and $\mathrm{Ga}(\alpha, \beta)$ variational posterior, all of these have a simple algebraic update in terms of sufficient statistics. Here, we overload notation to write $N_{tu} = \sum_m \delta_{u(m),u}\,\delta_{t(m),t}\,N_m$, $\overline{\theta}_u = \sum_m \delta_{u(m),u}\,\overline{\theta}_m$, and make use of $H$, $G$, and $F$ as defined in Eqs (7)-(9). Indeed, our implementation caches $F$ and $G$ for a substantial speedup (at the cost of additional memory requirements). For autocorrelated noise, the update rules are for the $j$-th item in the series of autocorrelation.
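The algebraic form of these updates is the familiar Gamma-Poisson pattern: the shape parameter accumulates counts and the rate parameter accumulates expected exposure. As a hedged illustration (the function name and arguments are hypothetical, and this generic form is an assumption rather than a transcription of any one table row), the overdispersion update for a single observation would use the count $N_m$ and the exposure $H_{0u(m)} F_{t(m)u(m)} G_{t(m)u(m)}$:

```python
def gamma_poisson_update(a, b, counts, exposures):
    """Conjugate update for a rate y with prior Ga(a, b) (shape, rate).

    Assumes counts[i] ~ Pois(y * exposures[i]), where exposures[i] collects
    the expected value of every other multiplicative factor in the rate.
    Returns the posterior parameters (alpha, beta).
    """
    alpha = a + sum(counts)
    beta = b + sum(exposures)
    return alpha, beta


if __name__ == "__main__":
    # One observation with count 5 and expected exposure H * F * G = 1.5.
    print(gamma_poisson_update(2.0, 3.0, [5], [1.5]))
```

Caching the exposure terms, as described above for $F$ and $G$, avoids recomputing the products for every variable that shares them.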
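Table 2. Conjugate updates for Gamma distributions

Variable | $\alpha - a$ | $\beta - b$
$\theta_m$ | $N_m$ | $H_{0u(m)} F_{t(m)u(m)} G_{t(m)u(m)}$
$s_u$ | $\frac{1}{2}\sum_m \delta_{u(m),u}$ | $\sum_m \delta_{u(m),u}\left[\overline{\theta}_m - \overline{\log \theta_m} - 1\right]$
$\lambda_{0u}$ | $\sum_t N_{tu}$ | $\overline{\theta}_u H_{0u} \sum_t F_{tu} G_{tu}$
$\lambda_{zuk}$ | $\sum_t N_{tu}\,\xi_{tk}$ | $\overline{\theta}_u H_{0u} \sum_t F_{tuk} G_{tu}$
$\lambda_{xur}$ | $\sum_t N_{tu}\,x_{tr}$ | cf. Section 2.2
$c_0$ | $U/2$ | $\sum_u \mathbb{E}_q[d_0 \lambda_{0u} - \log \lambda_{0u} - \log d_0 - 1]$
$c_{zk}$ | $U/2$ | $\sum_u \mathbb{E}_q[d_{zk} \lambda_{zuk} - \log \lambda_{zuk} - \log d_{zk} - 1]$
$c_{xr}$ | $U/2$ | $\sum_u \mathbb{E}_q[d_{xr} \lambda_{xur} - \log \lambda_{xur} - \log d_{xr} - 1]$
$d_0$ | $U\,\mathbb{E}_q[c_0]$ | $\sum_u \mathbb{E}_q[c_0 \lambda_{0u}]$
$d_{zk}$ | $U\,\mathbb{E}_q[c_{zk}]$ | $\sum_u \mathbb{E}_q[c_{zk} \lambda_{zuk}]$
$d_{xr}$ | $U\,\mathbb{E}_q[c_{xr}]$ | $\sum_u \mathbb{E}_q[c_{xr} \lambda_{xur}]$
$\phi_{ju}$ for $j \in \{1, \ldots, \tau\}$ | $\sum_{\tau' \geq j} N_{\tau' u}$ | $\sum_{\tau \geq j} F_{t(\tau)u} \prod_{\tau' \leq \tau,\, \tau' \neq j} \frac{\omega_{\tau' u}}{\zeta_{\tau' u}}$

2.2 Non-conjugate updates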
We employ explicit optimization steps for two updates in our iterative algorithm. In each case, we employ an off-the-shelf optimization routine, though more efficient alternatives are likely possible.

2.2.1 Covariate firing rate effects
Because we do not restrict the covariates $x(t)$ to be binary, Equation 6 no longer yields a conditional log probability for $\lambda_x$ in the exponential family. This leaves us with a transcendental equation to solve for $\beta_{xr}$. However, from Table 2, we see that $\alpha_{xr} \gg \sum_t x_{tr}$, allowing us to approximate $G_t \approx \prod_r (\alpha_{xr}/\beta_{xr})^{x_{tr}}$ as we have done above. Moreover, since the sum $\sum_t G_t \sim T\,\overline{G}$, we expect $\alpha_x/\beta_x \approx 1$ at the optimum for most reasonable datasets. We thus reparameterize $\beta_{xur} = \alpha_{xur}\,e^{-\epsilon_{ur}}$ and write the relevant $\epsilon$-dependent piece of the objective function $\mathcal{L}$ as

$$\mathcal{L}_\epsilon = \sum_{ur}\left[a_{xur}\,\epsilon_{ur} - b_{xur}\,e^{-\epsilon_{ur}}\right] - \sum_m \overline{\theta}_m H_{0u(m)} F_{t(m)u(m)}\,e^{-\sum_r \epsilon_{u(m)r}\,x_{t(m)u(m)r}} \tag{54}$$

We also supply the optimizer with the gradient:

$$\nabla_\epsilon \mathcal{L} = a_{xur} - b_{xur}\,e^{-\epsilon_{ur}} - \sum_{m; u(m)=u} x_{t(m)ur}\,\overline{\theta}_m H_{0u} F_{t(m)u}\,e^{-\sum_r \epsilon_{ur}\,x_{t(m)ur}} \tag{55}$$

where we sum only over observations with $u(m) = u$. On each iterate, we then optimize $\mathcal{L}_\epsilon$, initializing $\beta_x$ to the just-updated value of $\alpha_x$. In addition, we do not update firing rate effects for covariates separately, but optimize $\beta_{xur}$ for all $r$ together. Again, more efficient schemes are no doubt possible.
2.3 Semi-Markov duration distribution
If the dwell time distributions of states in the semi-Markov model were truly continuous, the parameters $(m, s^2)$ of a lognormal distribution $p(d|j)$ would have a Normal-Gamma conjugate prior, and updates would be closed-form. However, the requirement that the durations be integers and that there exist a maximal duration, $D$, over which these probability mass functions must normalize results in an extra term in $\mathbb{E}_q[\log p]$ arising from the normalization constant. In this case, we must explicitly optimize for the parameters of the Normal-Gamma prior on $(m, s^2)$. However, unlike the case of 2.2.1, these parameters are only updated for one latent feature at a time. Since the Normal-Gamma distribution has only 4 parameters and the number of states of the latent variable is $S = 2$, this requires updating $4SK$ parameters in total, but only $4S = 8$ at a time, a much more manageable task.
3 Experiments
Code for all algorithms and analyses is available at https://github.com/pearsonlab/spiketopics.
References

1. Beal MJ. Variational algorithms for approximate Bayesian inference. University of London; 2003.
2. Yu SZ, Kobayashi H. Practical implementation of an efficient forward-backward algorithm for an explicit-duration hidden Markov model. Signal Processing, IEEE Transactions on. 2006;54(5):1947–1951.
3. Mitchell CD, Jamieson LH. Modeling Duration in a Hidden Markov Model with the Exponential Family. In: Acoustics, Speech, and Signal Processing; 1993.
4. Mitchell C, Harper M, Jamieson L. On the complexity of explicit duration HMM's. IEEE Trans Audio Speech Lang Processing. 1995;3(3):213–217.
5. Blei DM, Jordan MI. Variational inference for Dirichlet process mixtures. Bayesian Anal. 2006;1(1):121–143.