Efficient Variational Inference for Gaussian Process Regression Networks
Trung Nguyen and Edwin Bonilla
Australian National University (ANU), National ICT Australia (NICTA)
Presented by: Simon O'Callaghan
Motivation
• Multi-output regression
• Complex correlations
fig. from Wilson et al.
Outline
• Gaussian process regression networks
• Variational inference for GPRNs
• Experiments
• Summary
Preliminary: Gaussian Processes
• A Gaussian process (GP) is specified by its mean and covariance function:
$$f(x) \sim \mathcal{GP}(m(x), k(x, x'))$$
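For concreteness, a minimal sketch (not part of the original slides) of drawing one function from a zero-mean GP prior; the squared-exponential kernel and all names here are illustrative choices:

```python
import numpy as np

def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared-exponential covariance k(x, x') = s^2 exp(-(x - x')^2 / (2 l^2)).
    sq_dists = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dists / lengthscale**2)

# Evaluate the prior covariance on a grid and draw f ~ N(0, K).
x = np.linspace(0.0, 5.0, 100)
K = rbf_kernel(x, x) + 1e-8 * np.eye(len(x))  # jitter for numerical stability
f_sample = np.random.multivariate_normal(np.zeros(len(x)), K)
```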
Gaussian process regression networks
• Motivation: prediction of outputs with complex correlations
• Generative perspective (sketched below)
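The model equation itself lives in the figure; for reference, the GPRN of Wilson et al. mixes $Q$ latent GP functions through an input-dependent weight matrix, consistent with the likelihood on the next slide:

$$y(x) = W(x)\, f(x) + \sigma_y \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I),$$

where $f(x) = [f_1(x), \ldots, f_Q(x)]^T$ collects the latent GPs and $W(x) \in \mathbb{R}^{P \times Q}$ has entries $w_{ij}(x)$ that are themselves GPs, so the output correlations vary with $x$.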
Inference for GPRNs (1)
• Notation: $X = \{x_n\}_{n=1}^N$, $D = \{y_n = y(x_n)\}_{n=1}^N$, $\theta = \{\theta_f, \theta_w, \sigma_y, \sigma_f = 0\}$
• Bayesian formulation:
  • prior: $p(f, w \mid \theta_f, \theta_w) = \prod_{i,j} \mathcal{N}(f_j; 0, K_f)\, \mathcal{N}(w_{ij}; 0, K_w)$
  • likelihood: $p(D \mid f, w, \sigma_y) = \prod_n \mathcal{N}\big(y_n; W(x_n) f(x_n), \sigma_y^2 I\big)$
• Each latent and weight function is an independent GP
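As a hedged illustration of this generative process (reusing rbf_kernel from the earlier sketch; the sizes and noise level are made up):

```python
import numpy as np

# Generative sketch of the GPRN prior and likelihood above.
N, P, Q = 50, 3, 2                         # inputs, outputs, latent functions
x = np.linspace(0.0, 5.0, N)
Kf = rbf_kernel(x, x) + 1e-8 * np.eye(N)   # K_f: prior covariance of each f_j
Kw = rbf_kernel(x, x, lengthscale=2.0) + 1e-8 * np.eye(N)  # K_w: of each w_ij
sigma_y = 0.1

# Prior draws: f_j ~ N(0, K_f), w_ij ~ N(0, K_w), all independent GPs.
f = np.random.multivariate_normal(np.zeros(N), Kf, size=Q)       # (Q, N)
W = np.random.multivariate_normal(np.zeros(N), Kw, size=(P, Q))  # (P, Q, N)

# Likelihood: y_n ~ N(W(x_n) f(x_n), sigma_y^2 I), computed for all n at once.
Y = np.einsum('pqn,qn->pn', W, f) + sigma_y * np.random.randn(P, N)
```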
Inference for GPRNs (2)
• Posterior:
$$p(f, w \mid D, \theta) = \frac{p(D \mid f, w)\, p(f, w)}{\int p(D \mid f, w)\, p(f, w)\, df\, dw}$$
• Intractable, so approximate inference is needed
• Bayesian inference for f and w, maximum likelihood for hyperparameters
• Variational message passing was used in the original paper
Inference for GPRNs (3)
• Variational inference: find the closest tractable approximation to the posterior (in KL divergence)
fig. from Bishop 2006
• Optimization: minimizing the KL divergence is equivalent to maximizing the evidence lower bound (ELBO):
$$\mathcal{L}(q) = \underbrace{\mathbb{E}_q[\log p(D \mid f, w)] + \mathbb{E}_q[\log p(f, w)]}_{\text{expected log joint}} + \underbrace{H[q(f, w)]}_{\text{entropy}}$$
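The equivalence follows from the standard decomposition of the log marginal likelihood, which the slide leaves implicit:

$$\log p(D \mid \theta) = \mathcal{L}(q) + \mathrm{KL}\big(q(f, w)\,\|\,p(f, w \mid D, \theta)\big),$$

so for fixed $\theta$ the left-hand side is constant and maximizing $\mathcal{L}(q)$ is the same as minimizing the KL term; since the KL is nonnegative, $\mathcal{L}(q)$ is indeed a lower bound on the evidence.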
Inference for GPRNs (4)
• Mean-field approximation:
$$q(f, w) = \prod_{j=1}^{Q} \underbrace{\mathcal{N}(f_j; \mu_{f_j}, \Sigma_{f_j})}_{q(f_j)} \prod_{i=1}^{P} \underbrace{\mathcal{N}(w_{ij}; \mu_{w_{ij}}, \Sigma_{w_{ij}})}_{q(w_{ij})}$$
where $f_j = [f_j(x_1), \ldots, f_j(x_N)]^T$ and $w_{ij} = [w_{ij}(x_1), \ldots, w_{ij}(x_N)]^T$
• $O(N^2)$ variational parameters for the covariance matrix of each factor
Inference for GPRNs (5)
• Mean-field results
• Exact ELBO and analytical solutions for the variational parameters:
$$\Sigma_{f_j} = \Big(K_f^{-1} + \frac{1}{\sigma_y^2} \sum_{i=1}^{P} \mathrm{diag}\big(\mu_{w_{ij}} \circ \mu_{w_{ij}} + \mathrm{Var}(w_{ij})\big)\Big)^{-1}$$
$$\Sigma_{w_{ij}} = \Big(K_w^{-1} + \frac{1}{\sigma_y^2} \mathrm{diag}\big(\mu_{f_j} \circ \mu_{f_j} + \mathrm{Var}(f_j)\big)\Big)^{-1}$$
• Only O(N) parameters needed for each factor
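A minimal sketch of this update for one latent factor, assuming mu_w and var_w hold the posterior means and marginal variances of the weights $w_{ij}$ at the training inputs (shapes and names are illustrative):

```python
import numpy as np

# Fixed-point update for Sigma_{f_j}; mu_w and var_w have shape (P, N).
def update_Sigma_fj(Kf, mu_w, var_w, sigma_y):
    # d = (1 / sigma_y^2) * sum_i (mu_wij ∘ mu_wij + Var(w_ij)), length-N vector.
    d = np.sum(mu_w**2 + var_w, axis=0) / sigma_y**2
    # Sigma_{f_j} = (K_f^{-1} + diag(d))^{-1}; only the N entries of d are free,
    # which is why each factor needs just O(N) variational parameters.
    return np.linalg.inv(np.linalg.inv(Kf) + np.diag(d))
```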
Inference for GPRN (6)
• Nonparametric approximation:
$$q(f, w) = \frac{1}{K} \sum_{k=1}^{K} \overbrace{\prod_{j=1}^{Q} \underbrace{\mathcal{N}\big(f_j; \mu_{f_j}^{(k)}, \sigma_k^2 I\big)}_{q(f_j^{(k)})} \prod_{i=1}^{P} \underbrace{\mathcal{N}\big(w_{ij}; \mu_{w_{ij}}^{(k)}, \sigma_k^2 I\big)}_{q(w_{ij}^{(k)})}}^{q^{(k)}(f, w)}$$
• Each component is an isotropic Gaussian
• O(KN) variational parameters only (K < 5)
Inference for GPRN (7)
• NPV results
• Analytical lower bound for the ELBO
• Previous method used a second-order approximation of L(q):
$$\mathcal{L}(q) \geq \underbrace{\frac{1}{K} \sum_{k=1}^{K} \Big( \mathbb{E}_{q^{(k)}}[\log p(D \mid f, w)] + \mathbb{E}_{q^{(k)}}[\log p(f, w)] \Big)}_{\text{analytically tractable as in MF}} \; \underbrace{- \frac{1}{K} \sum_{k=1}^{K} \log \frac{1}{K} \sum_{j=1}^{K} \mathcal{N}\big(\mu^{(k)}; \mu^{(j)}, (\sigma_k^2 + \sigma_j^2) I\big)}_{\text{lower bound on } H[q(f, w)]}$$
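A sketch of the entropy term of this bound, assuming the component means are stacked as rows of mu with matching isotropic variances sig2 (illustrative names):

```python
import numpy as np

# Lower bound on H[q] for a uniform mixture of K isotropic Gaussians:
# H[q] >= -(1/K) sum_k log (1/K) sum_j N(mu^(k); mu^(j), (sig2_k + sig2_j) I).
def entropy_lower_bound(mu, sig2):
    K, D = mu.shape                      # K components in D dimensions
    total = 0.0
    for k in range(K):
        s2 = sig2[k] + sig2              # pairwise variance sums, shape (K,)
        sq = np.sum((mu[k] - mu)**2, axis=1)
        log_pdf = -0.5 * (D * np.log(2 * np.pi * s2) + sq / s2)
        total += np.logaddexp.reduce(log_pdf - np.log(K))  # log-sum-exp over j
    return -total / K
```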
Inference for GPRNs (8): Summary
• Two families of distributions for variational inference
• $O(N)$ variational parameters (cf. $O(N^2)$ for a standard variational Gaussian)
• Approximations of relatively complex posteriors
• Closed-form ELBO, which allows model selection and learning of hyperparameters
Experiments (1)
• Datasets
  • Jura: prediction of heavy metal concentrations (Cd, Ni, Zn)
  • Concrete: prediction of concrete qualities (Slump, Flow, Compressive Strength)
[Figure: output variables per dataset, with "?" marking the outputs to be predicted]
Experiments (2)
[Figure: mean absolute error (MAE) on Jura and standardized mean squared error (SMSE) on Concrete (Slump, Flow, Compressive Strength), comparing IGP, MF, NPV1, NPV2, NPV3]
Summary
• GPRNs for input-dependent (adaptive) correlations
• Two tractable and statistically efficient families of variational distributions
• Future work:
  • Simplify the GPRN model for less intensive inference without losing its flexibility
  • Extend/apply GPRNs to other multi-task problems, e.g., classification, preferences
  • Scalability issues
Questions?
• Thank you!