Learning Implicit Generative Models Using Differentiable Graph Tests

arXiv:1709.01006v1 [stat.ML] 4 Sep 2017

Josip Djolonga, ETH Zürich, [email protected]

Andreas Krause, ETH Zürich, [email protected]

Abstract

Recently, there has been a growing interest in the problem of learning rich implicit models — those from which we can sample, but can not evaluate their density. These models apply some parametric function, such as a deep network, to a base measure, and are learned end-to-end using stochastic optimization. One strategy for devising a loss function is through the statistics of two-sample tests — if we can fool a statistical test, the learned distribution should be a good model of the true data. However, not all tests can easily fit into this framework, as they might not be differentiable with respect to the data points, and hence with respect to the parameters of the implicit model. Motivated by this problem, in this paper we show how two such classical tests, the Friedman-Rafsky and k-nearest neighbour tests, can be effectively smoothed using ideas from undirected graphical models — the matrix tree theorem and cardinality potentials. Moreover, as we show experimentally, smoothing can significantly increase the power of the test, which might be of independent interest. Finally, we apply our method to learn implicit models.

1 Introduction

The main motivation for our work is that of learning implicit models, i.e., those from which we can easily sample, but can not evaluate their density. Formally, we can generate a sample from an implicit distribution $Q$ by first drawing $z$ from some known and fixed distribution $Q_0$, typically Gaussian or uniform, and then passing it through some differentiable function $f_\theta$ parametrized by some vector $\theta$ to generate $x = f_\theta(z) \sim Q$. The goal is then to optimize the parameters $\theta$ of the mapping $f_\theta$ so that $Q$ is as close as possible to some target distribution $P$, which we can access only via i.i.d. samples. The approach that we undertake in this paper is that of defeating statistical two-sample tests. These tests operate in the following setting — given two sets of i.i.d. samples, $X_1 = \{x_1, x_2, \dots, x_{n_1}\}$ from $P$, and $X_2 = \{x_{n_1+1}, x_{n_1+2}, \dots, x_{n_1+n_2}\}$ from $Q$, we have to distinguish between the following hypotheses
$$H_0 : P = Q \quad \text{vs.} \quad H_1 : P \neq Q.$$
The tests that we consider start by defining a function $T : (\mathbb{R}^d)^{n_1} \times (\mathbb{R}^d)^{n_2} \to \mathbb{R}$ that should result in a low value if the two samples come from different distributions. Then, the hypothesis $H_0$ is rejected at significance level $\alpha \in [0, 1]$ if $T(X_1, X_2)$ is lower than some threshold $t_\alpha$, which is computed using a permutation test, as explained in Section 2. Going back to the original problem, one intuitive approach would be to maximize the expected statistic $\mathbb{E}_{x_i \sim P,\, z_i \sim Q_0}\big[T(\{x_i\}_{i=1}^{n_1}, \{f_\theta(z_i)\}_{i=n_1+1}^{n_1+n_2})\big]$ using stochastic optimization over the parameters of the mapping $f_\theta$. However, this requires the availability of the derivatives $\partial T / \partial x_i$, which is unfortunately not always possible. For example, the Friedman-Rafsky (FR) and k-nearest neighbours (k-NN) tests, which have very desirable statistical properties (including consistency and convergence of their statistics to f-divergences), can not be cast in the above framework as they use the output of a combinatorial optimization problem.

Our main contribution is the development of differentiable versions of these tests that remedy the above problem by smoothing their statistics. We moreover show, similarly to these classical tests, that our tests are asymptotically normal under certain conditions, and derive the corresponding t-statistic, which can be evaluated with minimal additional complexity. Our smoothed tests can have more power than their classical variants, as we showcase with numerical experiments. Finally, we experimentally learn implicit models in Section 5.

Related work. The problem of two-sample testing for distributional equality has received significant interest in statistics. For example, the celebrated Kolmogorov-Smirnov test compares two one-dimensional distributions by taking the maximal difference of the empirical CDFs. Another one-dimensional test is the runs test of Wald and Wolfowitz [1], which has been extended to the multivariate case by Friedman and Rafsky [2] (FR). It is exactly this test, together with the k-NN test originally suggested in [3], that we analyze. These tests have been analyzed in more detail by Henze and Penrose [4], and by Henze [5] and Schilling [6] respectively. Their asymptotic efficiency has been discussed by Bhattacharya [7]. Chen and Zhang [8] considered the problem of tie breaking when applying the FR tests to discrete data and suggested averaging over all minimal spanning trees, which can be seen as a special case of our test in the low-temperature setting. A very prominent test that has been more recently developed is the kernel maximum mean discrepancy (MMD) test of Gretton et al. [9], which we compare with in Section 5. The test statistic is differentiable and has been used for learning implicit models by Li et al. [10] and Dziugaite et al. [11]. Sutherland et al. [12] consider the problem of learning the kernel by creating a t-statistic using a variance estimator. Moreover, they also pioneered the idea of using tests for model criticism — for two fixed distributions, one optimizes over the parameters of the test (the kernel used). The energy test of Székely and Rizzo [13], a special case of the MMD test, has been used by Bellemare et al. [14]. Other approaches for learning implicit models that do not depend on two-sample tests have been developed as well. For example, one approach is by estimating the log-ratio of the distributions [15]. Another approach, which has recently sparked significant interest and can also be seen as estimating the log-ratio of the distributions, are the generative adversarial networks (GANs) of Goodfellow et al. [16], who pose the problem as a two-player game. One can, as done in [12], combine GANs with two-sample tests by using them as feature matchers at some layer of the generating network [17]. Nowozin et al. [18] minimize an arbitrary f-divergence [19] using a GAN framework, which can be related to our approach, because the statistics of our tests converge to specific f-divergences, as explained in Section 2. For an overview of various approaches to learning implicit models we direct the reader to Mohamed and Lakshminarayanan [20].

2 Classical Graph Tests

Let us start by introducing some notation. For any set $X = \{x_1, x_2, \dots, x_n\}$ of points in $\mathbb{R}^d$, we will denote by $\mathcal{G}(X) = (X, E)$ the complete directed graph¹ defined over the vertex set $X$ with edges $E$. We will moreover weigh this graph using some function $d : \mathbb{R}^d \times \mathbb{R}^d \to [0, \infty)$, e.g. a natural choice would be $d(x, x') = \|x - x'\|$. Similarly, we will use $d(e)$ for the weight of the edge $e$ under $d(\cdot, \cdot)$. For any labelling of the vertices $\pi : X \to \{1, 2\}$, and any edge $e \in E$ with adjacent vertices $i$ and $j$, we define² $\Delta_\pi(e) = [\![\pi(i) \neq \pi(j)]\!]$, i.e., $\Delta_\pi(e)$ indicates whether the end points of $e$ have different labels under $\pi$. Remember that we are given $n_1$ points $X_1 = \{x_1, x_2, \dots, x_{n_1}\}$ from $P$, and $n_2$ points $X_2 = \{x_{n_1+1}, x_{n_1+2}, \dots, x_{n_1+n_2}\}$ from $Q$. In the remainder of the paper we will use $n = n_1 + n_2$ for the total number of points.

¹ For the FR test we will arbitrarily choose one of the two edges for each pair of nodes.
² We use the Iverson bracket $[\![S]\!]$ that evaluates to 1 if $S$ is true and 0 otherwise.

The tests are based on the following four-step strategy.

(i) Pool the samples $X_1$ and $X_2$ together into $X = X_1 \cup X_2 = \{x_1, x_2, \dots, x_{n_1+n_2}\}$, and create the graph $\mathcal{G}(X)$. Define the mapping $\pi^* : X \to \{1, 2\}$ evaluating to 1 on $X_1$ and to 2 on $X_2$.

(ii) Using some well-defined algorithm $\mathcal{A}$, choose a subset $U^* = \mathcal{A}(\mathcal{G}(X))$ of the edges of this graph, with the underlying motivation that it defines some neighbourhood structure.

(iii) Count how many edges in $U^*$ connect points from $X_1$ with points from $X_2$, i.e., compute the statistic $T_{\pi^*}(U^*) = \sum_{e \in U^*} \Delta_{\pi^*}(e)$.

(iv) Reject $H_0$ for small values of $T_{\pi^*}(U^*)$.

Figure 1: The functions generating the f-divergences (FR $D_{1/2}$ and NN $D_{1/2}$).

These tests condition on the data and are executed as permutation tests, so that the critical value in step (iv) is computed using the quantiles of $T_\pi(U^*)$, where $\pi : X \to \{1, 2\}$ is drawn uniformly at random from the set of $\binom{n_1+n_2}{n_1}$ labellings that map exactly $n_1$ points of $X$ to 1. Formally, the p-value is given as $\mathbb{E}_{\pi \sim H_0}\big[[\![T_{\pi^*}(U^*) \geq T_\pi(U^*)]\!]\big]$. We are now ready to introduce the two tests that we consider in this paper, which are obtained by using a different neighbourhood selection algorithm $\mathcal{A}$ in step (ii).
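As a reference point, a plain (non-smoothed) version of this recipe can be written in a few lines. The sketch below uses SciPy's minimum spanning tree for step (ii) of the FR test and a permutation loop for the p-value; the function names are our own and not from the paper's code.

    import numpy as np
    from scipy.sparse.csgraph import minimum_spanning_tree
    from scipy.spatial.distance import cdist

    def fr_statistic(x, labels):
        """Number of MST edges joining differently labelled points (steps (i)-(iii) for the FR test)."""
        mst = minimum_spanning_tree(cdist(x, x)).tocoo()
        return int(np.sum(labels[mst.row] != labels[mst.col]))

    def permutation_p_value(x1, x2, n_permutations=1000, seed=0):
        x = np.vstack([x1, x2])
        labels = np.r_[np.zeros(len(x1)), np.ones(len(x2))]
        rng = np.random.default_rng(seed)
        observed = fr_statistic(x, labels)
        null = [fr_statistic(x, rng.permutation(labels)) for _ in range(n_permutations)]
        # step (iv): reject H0 for small values, so the p-value is the fraction of
        # permuted statistics that are at most the observed one
        return float(np.mean([t <= observed for t in null]))

The permutation loop is exactly the non-differentiable step that the rest of the paper replaces with a smooth surrogate.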

Friedman-Rafsky (FR). This test, developed by Friedman and Rafsky [2], uses the minimum spanning tree (MST) of $\mathcal{G}(X)$ as the neighbourhood structure $U^*$, which can be computed using the classical algorithms of Prim [21] and Kruskal [22] in time $O(n^2 \log n)$. If we use $d(x_i, x_j) = \|x_i - x_j\|$, the problem is also known as the Euclidean spanning tree problem, and in this case Henze and Penrose [4] have proven that the test is consistent and has the following asymptotic limit.

Theorem 1 ([4]). If $d(x, x') = \|x - x'\|$ and $n_1/(n_1+n_2) \to \alpha \in (0, 1)$, then it almost surely holds that
$$\frac{T_{\pi^*}(U^*)}{n_1 + n_2} \to 2\alpha(1-\alpha) \int \frac{p(x)\, q(x)}{\alpha p(x) + (1-\alpha) q(x)}\, dx,$$
where $p$ and $q$ are the densities of $P$ and $Q$.

As noted by Berisha and Hero [23], after some algebraic manipulation of the right hand side of the above equation, we obtain that $1 - T_{\pi^*}(U^*)\,\frac{n_1+n_2}{2 n_1 n_2}$ converges almost surely to the following f-divergence [19]
$$D_\alpha^{FR}(P \,\|\, Q) = \frac{1}{4\alpha(1-\alpha)} \int \frac{(\alpha p(x) - (1-\alpha) q(x))^2}{\alpha p(x) + (1-\alpha) q(x)}\, dx \;-\; \frac{(2\alpha-1)^2}{4\alpha(1-\alpha)}.$$
In [23] it is also noted that if $n_1 = n_2$, then $\alpha = 1/2$, and in that case $D_{1/2}^{FR}$ is equal to $\frac{1}{2}\int \frac{(p(x)-q(x))^2}{p(x)+q(x)}\, dx$, which is known as the symmetric $\chi^2$ divergence.

k-nearest-neighbours (k-NN). Maybe the most intuitive way to construct a neighbourhood structure is to connect each point $x_j \in X$ to its $k$ nearest neighbours. Specifically, we will add the edge $x_i \to x_j$ to $U^*$ iff $x_i$ is one of the $k$ closest neighbours of $x_j$ as measured by $d(x, x')$. If one uses the Euclidean norm, then the asymptotic distribution and the consistency of the test have been proven by Schilling [6]. These results have been extended to arbitrary norms by Henze [5], who also proved the limiting behaviour of the statistic as $n \to \infty$.

Theorem 2 ([5]). If $n_1/(n_1+n_2) \to \alpha \in (0, 1)$, then $1 - \frac{T_{\pi^*}(U^*)}{(n_1+n_2)\, k}$ converges in probability to
$$D_\alpha^{NN}(P \,\|\, Q) \equiv \int \frac{\alpha^2 p^2(x) + (1-\alpha)^2 q^2(x)}{\alpha p(x) + (1-\alpha) q(x)}\, dx,$$
where $p$ and $q$ are the continuous densities of $P$ and $Q$.

As for the FR test, we can also re-write the limit as an f-divergence³ corresponding to $f(t) = (\alpha^2 t^2 + (1-\alpha)^2)/(\alpha t + (1-\alpha))$. Moreover, if we compare the integrands in $D_\alpha^{FR}$ and $D_\alpha^{NN}$, we see that they are closely related: they differ by the term $2\alpha(1-\alpha)\, p(x) q(x)$ in the numerator. The fact that they are closely related can also be seen from Figure 1, where we plot the corresponding f-functions for the $n_1 = n_2$ case.

³ This $f$ does not vanish at one, but we can simply shift it.

3 Differentiable Graph Tests

While the tests from the previous section have been studied from a statistical perspective, we can not use them to train implicit models because the derivatives $\partial T / \partial x_i$ are either zero or do not exist, as $T$ takes on finitely many values. The strategy that we undertake in this paper is to smooth them into continuously differentiable functions by relaxing them to expectations in natural probabilistic models. To motivate the models we will introduce, note that for both the k-NN and the FR test, the optimal neighbourhood is the solution to the following optimization problem
$$U^* = \arg\min_{U \subseteq E} \sum_{e \in U} d(e) \quad \text{s.t.} \quad \nu(U) = 1, \tag{1}$$
where $\nu : 2^E \to \{0, 1\}$ indicates if the set of edges is valid, i.e., if every vertex has exactly $k$ neighbours in the k-NN case, or if the set of edges forms a poly-tree in the MST case. Moreover, note that once we fix $n_1$ and $n_2$, the optimization problem (1) depends only on the edge weights $d(e)$, which we will concatenate in an arbitrary order and store in the vector $d \in \mathbb{R}^{|E|}$. We want to design a probability distribution over $U$ that focuses on those configurations $U$ that are both feasible and have a low cost for problem (1). One such natural choice is the following Gibbs measure
$$P(U \mid d/\lambda) = e^{-\sum_{e \in U} d(e)/\lambda - A(-d/\lambda)}\, \nu(U), \tag{2}$$
where $\lambda$ is the so-called temperature parameter, and $A(-d/\lambda)$ is the log-partition function that ensures that the distribution is normalized. Note that $U^*$ is a MAP configuration of this distribution (2), and the distribution will concentrate on the MAP configurations as $\lambda \to 0$.

Once we have fixed the model, the strategy is clear — replace the statistic $T_{\pi^*}(U^*)$ with its expectation $\mathbb{E}_U[T_{\pi^*}(U)]$, which results in the following smooth statistic
$$T_{\pi^*}(U^*) \;\longrightarrow\; T_{\pi^*}^\lambda \equiv \mathbb{E}_{U \sim P(\cdot \mid d, \lambda)}[T_{\pi^*}(U)] = \sum_{e \in E} \Delta_{\pi^*}(e)\, \mu(d/\lambda)_e,$$
where $\mu(d/\lambda)$ are the marginal probabilities of the edges, i.e., $[\mu(d/\lambda)]_e = \mathbb{E}_{P(U \mid d/\lambda)}[[\![e \in U]\!]]$. Hence, we can compute the statistic as long as we can perform inference in (2). To compute its derivatives we can use the fact that (2) is a member of the exponential family. Namely, leveraging the classical properties of the log-partition function [24, Prop. 3.1], we obtain the following identities
$$\mu(d/\lambda) = \nabla A(-d/\lambda), \quad \text{and} \quad \frac{\partial \mu(d/\lambda)_e}{\partial [-d/\lambda]_{e'}} = \mathbb{E}_{P(U \mid d/\lambda)}\big[[\![\{e, e'\} \subseteq U]\!]\big] - \mu(d/\lambda)_e\, \mu(d/\lambda)_{e'}. \tag{3}$$
Thus, if we can compute both first- and second-order moments under (2), we get both the smoothed statistic and its derivative. We show how to do this for the k-NN and FR tests in Section 4.

A smooth p-value. Even though one can directly use the smoothed test statistic $T_{\pi^*}^\lambda$ as an objective when learning implicit models, it does not necessarily mean that higher values of this statistic result in higher p-values. Remember that to compute a p-value, one has to run a permutation test by computing quantiles of $T_\pi^\lambda$ under random draws of the permutation $\pi \sim H_0$. However, as this procedure is not smooth and can be costly to compute, we suggest as an alternative that does not suffer from these problems the following t-statistic
$$t_{\pi^*}^\lambda = \frac{T_{\pi^*}^\lambda - \mathbb{E}_{\pi \sim H_0}[T_\pi^\lambda]}{\sqrt{\mathbb{V}_{\pi \sim H_0}[T_\pi^\lambda]}}. \tag{4}$$
The same strategy has been undertaken for the FR and k-NN tests in [2, 4, 6]. Before we show how to compute the first two moments under $H_0$, we need to define the matrix $\Pi$ holding the second moments of the variables $\Delta_\pi(e)$.

Lemma 1 ([2]). The matrix $\Pi \in \mathbb{R}^{|E| \times |E|}$ with entries $\Pi_{e,e'} = \mathbb{E}_{\pi \sim H_0}[\Delta_\pi(e)\, \Delta_\pi(e')]$ is equal to
$$\Pi_{e,e'} = \begin{cases} \frac{2 n_1 n_2}{n(n-1)} & \text{if } \delta(e) = \delta(e'), \\ \frac{n_1 n_2}{n(n-1)} & \text{if } |\delta(e) \cap \delta(e')| = 1, \\ \frac{4 n_1 n_2 (n_1-1)(n_2-1)}{n(n-1)(n-2)(n-3)} & \text{if } \delta(e) \cap \delta(e') = \emptyset, \end{cases}$$
where $\delta(e)$ is the set of vertices incident to the edge $e \in E$.

Theorem 3. Assume that all valid configurations $U$ satisfy $|U| = m$, i.e. that $\nu(U) \neq 0$ implies $|U| = m$.⁴ Then, the first two moments of the statistic under $H_0$ are
$$\mathbb{E}_{\pi \sim H_0}[T_{\pi^*}^\lambda] = \frac{2 m n_1 n_2}{n(n-1)}, \quad \text{and} \quad \mathbb{V}_{\pi \sim H_0}[T_{\pi^*}^\lambda] = \mu(d/\lambda)^T \Pi\, \mu(d/\lambda) - \frac{4 n_1^2 n_2^2}{n^2 (n-1)^2}\, m^2.$$

⁴ Note that we have $m = kn$ for k-NN and $m = n - 1$ for FR.

While the computation of the mean is trivial, it seems that the computation of the variance needs $O(|E|^2)$ operations. However, we can simplify its computation to $O(|E|)$ using the following result.

Lemma 2. Define $\chi_1 = \frac{n_1 n_2}{n(n-1)}$ and $\chi_2 = \frac{4(n_1-1)(n_2-1)}{(n-2)(n-3)}$. Then, the variance can be computed as
$$\sigma^2 = \chi_1 (1 - \chi_2) \sum_v \Big(\sum_{e \in \delta(v)} \mu_e\Big)^2 + \chi_1 \chi_2 \sum_{e \parallel e'} \mu_e \mu_{e'} + \chi_1 (\chi_2 - 4\chi_1)\, m^2,$$
where $\sum_{e \parallel e'}$ sums over all pairs of parallel edges, i.e., those connecting the same end-points.

Approximate normality of $t_{\pi^*}^\lambda$. To better motivate the use of a t-statistic, we can, similarly to the arguments in [2, 4, 6], show that it is close to a normal distribution by casting it as a generalized correlation coefficient [25, 3]. Namely, these are tests whose statistics are of the form $\kappa = \sum_{i=1}^n \sum_{j=1}^n \mu_{i,j}\, b_{i,j}$, and whose critical values are computed using the distribution of $\sum_{i=1}^n \sum_{j=1}^n \mu_{i,j}\, b_{\pi(i),\pi(j)}$, where $\pi$ is a random permutation on $\{1, 2, \dots, n\}$. It is easily seen that we can fit the suggested tests in this framework if we set $\mu_{i,j} = \frac{1}{2}(\mu(d/\lambda)_{i \to j} + \mu(d/\lambda)_{j \to i})$ and $b_{i,j} = \Delta_{\pi^*}(\{i, j\})$. Then, using the conditions of Barbour and Eagleson [26], we obtain the following bound on the deviation from normality.

Theorem 4. Let $n_1/(n_1+n_2) \to \alpha \in (0, 1)$, and define
• $S_2 = \sum_{i,j,k} \mu_{i,j}\, \mu_{i,k}$, i.e., the expected number of edges sharing a vertex,
• $S_3 = \sum_{i,j,k,m} \mu_{i,j}\, \mu_{i,k}\, \mu_{i,m}$, i.e., the expected number of 3-stars, and
• $L_4 = \sum_{i,j,k,m} \mu_{i,j}\, \mu_{j,k}\, \mu_{k,m}$, i.e., the expected number of paths with 4 nodes.
Then, the Wasserstein distance between the permutation null distribution of $t_\pi^\lambda$ and the standard normal is of order
$$O\big( (n k^3 + k S_2 + S_3 + L_4)/\sigma^3 \big).$$

Let us analyze the above bound in the setting in which we will use it — when $n_1 = n_2$. First, let us look at the variance, as formulated in Lemma 2. The last term can be ignored as it is always non-negative because $\chi_2 \geq 4\chi_1$ (shown in the appendix). Because $\sum_{e \in \delta(v)} \mu_e \geq 1$, it follows that the variance grows as $\Omega(n)$. Thus, without any additional assumption on the growth of the neighbourhoods, we have asymptotic normality as $n \to \infty$ if the numerator is of order $o(n^{1.5})$. For example, that would be satisfied if the largest neighbourhood $\max_i \sum_{e \in \delta(i)} \mu_e$ grows as $o(n^{1/6})$. Note that in the low temperature setting (when $\lambda \to 0$), the coordinates of $\mu$ will be very close to either zero or one. As observed by Friedman and Rafsky [2], in this case $S_2 = O(1)$, as both the k-NN and MST graphs have nodes whose degree is bounded by a constant independent of $n$ as $n \to \infty$ [27]. We also observe experimentally in Section 5 that the distribution gets closer to normality as $\lambda$ decreases.
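As a concrete illustration of Theorem 3 and Lemma 2, the following sketch computes the permutation-null mean and variance, and hence the t-statistic in (4), from a matrix of directed edge marginals. The matrix mu is assumed to be given, for example by the inference routines of Section 4, and all function names are our own.

    import torch

    def null_moments(mu, n1, n2):
        """Mean and variance of the smoothed statistic under H0 (Theorem 3 / Lemma 2).

        mu[i, j] is the marginal probability of the directed edge i -> j (zero diagonal).
        """
        n = n1 + n2
        m = mu.sum()                                      # expected number of selected edges
        chi1 = n1 * n2 / (n * (n - 1))
        chi2 = 4 * (n1 - 1) * (n2 - 1) / ((n - 2) * (n - 3))
        mean = 2 * m * chi1
        vertex_sums = mu.sum(dim=0) + mu.sum(dim=1)       # marginals of edges touching each vertex
        parallel = (mu * mu).sum() + (mu * mu.t()).sum()  # pairs of edges with the same end-points
        var = (chi1 * (1 - chi2) * (vertex_sums ** 2).sum()
               + chi1 * chi2 * parallel
               + chi1 * (chi2 - 4 * chi1) * m ** 2)
        return mean, var

    def smoothed_t_statistic(mu, labels, n1, n2):
        """t-statistic (4): standardized expected count of cross-sample edges."""
        cross = (labels[:, None] != labels[None, :]).float()
        t_lambda = (mu * cross).sum()
        mean, var = null_moments(mu, n1, n2)
        return (t_lambda - mean) / torch.sqrt(var)

Since everything above is built from sums and products of the marginals, the t-statistic stays differentiable whenever the marginals are.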

4 The Differentiable k-NN and Friedman-Rafsky Tests

In this section, we discuss these two tests in more detail and show how to efficiently compute their statistics.

Remember that to compute and optimize both $T_{\pi^*}^\lambda$ and $t_{\pi^*}^\lambda$ we have to be able to perform inference in the model $P(U) = \exp(-\sum_{e \in U} d(e)/\lambda - A(-d/\lambda))\, \nu(U)$ by computing the first- and second-order moments of the edge indicator variables. We stress that, in the learning setting that we consider, $n$ refers to the number of data points in a mini-batch.

k-NN. The constraint $\nu(\cdot)$ in this case requires the total number of edges in $U$ incoming at each node to be exactly $k$. First, note that the problem completely separates per node, i.e., the marginals of edges with different target vertices are independent. Formally, if we denote by $U_i$ the set of edges incoming at vertex $i$, then $U_i$ and $U_j$ are independent for $i \neq j$. Hence, for each node $i$ separately, we have to perform inference in
$$P(U_i) \propto \exp\Big(-\sum_{j \in U_i} d(x_i, x_j)/\lambda\Big)\, [\![|U_i| = k]\!],$$
which is a special case of the cardinality potentials considered by Tarlow et al. [28] and Swersky et al. [29]. Swersky et al. [29] consider the same model, and note that we can compute all marginals in time $O(nk)$ using the algorithm in [28], which works by re-writing the model as a chain CRF and running the classical forward-backward algorithm. Hence, the total time complexity to compute the vector $\mu(d/\lambda)$ is $O(n^2 k)$. Moreover, as marginalization requires only simple operations, we can compute the derivatives with any automatic differentiation software, and we thus do not provide formulas for the second-order moments. In [29] the authors provide an approximation for the Jacobian, which we did not use in our experiments; instead we differentiate through the messages of the forward-backward algorithm.

As a concrete example, let us work out the simplest case — the k-NN test with $k = 1$. In this case, the smoothed statistic reduces to
$$T_{\pi^*}^\lambda(x_1, \dots, x_n) = \sum_{i=1}^n \sum_{\substack{j=1 \\ \pi^*(i) \neq \pi^*(j)}}^n s_i(x_1, \dots, x_n)_j,$$
where $s_i(x_1, \dots, x_n) = \mathrm{softmax}(-\otimes_{l \neq i} \|x_i - x_l\|/\lambda)$. In other words, for each $i$ you compute the softmax of the distances to all other points using $s_i$, and then sum up only those positions that correspond to points from the other sample. One interpretation of the loss is the following — maximize the number of incorrect predictions if we are to estimate the label $\pi(i)$ from $x_i$ using a soft 1-nearest-neighbour approach. Furthermore, we can also make a clear connection between the smooth 1-NN test and neighbourhood component analysis (NCA) [30]. Namely, we can see NCA as learning a mapping $h : x \mapsto Ax$ so that the test distinguishes (by minimizing $T_{\pi^*}^\lambda$) the two samples as well as possible after applying $h$ to them. The extension of NCA to k-NN [31] can also be seen as minimizing the test statistic for a particular instance of their loss function.
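To make the $k = 1$ case concrete, here is a minimal differentiable sketch of the smoothed 1-NN statistic in PyTorch, following the softmax expression above. The function names and the use of torch.cdist are our own choices, not taken from the paper's code.

    import torch

    def soft_1nn_marginals(x, lam):
        """mu[i, j]: soft probability that x_j is the 1-nearest neighbour of x_i."""
        dist = torch.cdist(x, x)
        dist = dist.masked_fill(torch.eye(len(x), dtype=torch.bool), float('inf'))
        return torch.softmax(-dist / lam, dim=1)

    def smoothed_1nn_statistic(x1, x2, lam):
        """Expected number of soft nearest-neighbour edges crossing the two samples."""
        x = torch.cat([x1, x2], dim=0)
        labels = torch.cat([torch.zeros(len(x1)), torch.ones(len(x2))])
        mu = soft_1nn_marginals(x, lam)
        cross = (labels[:, None] != labels[None, :]).float()
        return (mu * cross).sum()

Because the statistic is an ordinary composition of differentiable operations, maximizing it with respect to the points in x2 (for example, the outputs of a generator) is straightforward with automatic differentiation.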

Friedman-Rafsky. The model that we have to perform inference in for this test seems extremely complicated and intractable at first, because the constraint has the form $\nu(U) = [\![U \text{ forms a spanning tree}]\!]$. First, note that if $d/\lambda$ had all entries equal to a constant $\gamma$, we would have $A(-d/\lambda) = (1-n)\gamma + \log c_{\mathcal{G}(X)}$, where $c_{\mathcal{G}(X)}$ is the number of spanning trees in the graph $\mathcal{G}(X)$, which can be computed using Kirchhoff's (also known as the matrix-tree) theorem. To treat the weighted case, we use the approach of Lyons [32], who has shown that the above model is a determinantal point process (DPP), so that marginalization can be done exactly as follows. First, create the incidence matrix $A \in \{-1, 0, +1\}^{(n-1) \times |E|}$ of the graph $\mathcal{G}(X)$ after removing an arbitrary vertex, and construct its Laplacian $L = A\, \mathrm{diag}(\exp(-d/\lambda))\, A^T$. Then, if we compute $H = L^{-1/2} A\, \mathrm{diag}(\exp(-d/(2\lambda)))$, the distribution $P(U)$ is a DPP with kernel matrix $K = H^T H$, implying that for every $W \subseteq E$
$$\mathbb{E}_{P(U \mid d/\lambda)}[[\![W \subseteq U]\!]] = \det K_W,$$
where $K_W$ is the $|W| \times |W|$ submatrix of $K$ formed by the rows and columns indexed by $W$. Thus, we can easily compute all marginals and the smoothed test statistic and its derivatives using (3) as
$$\mu_{i \to j} = e^{-d(x_i, x_j)/\lambda}\, (u_i - u_j)^T L^{-1} (u_i - u_j), \quad \text{and} \quad \frac{\partial \mu_{i \to j}}{\partial [d/\lambda]_{k \to l}} = e^{-\frac{d(x_i, x_j) + d(x_k, x_l)}{\lambda}} \big( (u_i - u_j)^T L^{-1} (u_k - u_l) \big)^2,$$
where $u_i$ is the vector with coordinates equal to zero, except the $i$-th coordinate, which is one. Note that if we first compute the inverse $L^{-1}$, all quantities of the form $L^{-1}(u_i - u_j)$ can be computed in time $O(n)$ as the vectors $u_i$ have a single non-zero entry, for a total complexity of $O(n^3)$.

To speed up this computation we can leverage the existing theory on fast solvers of Laplacian systems. Let us first create from $\mathcal{G}(X)$ the graph $\tilde{\mathcal{G}}(X)$ that has the same structure as $\mathcal{G}(X)$, but with edge weights $e^{-d(e)/\lambda}$ instead of $d(e)$. Hence, in this graph, a large weight between $x$ and $x'$ indicates that these two points are similar to one another. In $\tilde{\mathcal{G}}(X)$, the marginals $\mu_e$ are also known as effective resistances.⁵ Spielman and Srivastava [34] provide a method to compute all marginals at once in time $\tilde{O}(r n^2 / \varepsilon^2)$, where $\varepsilon$ is the desired relative precision and $r = \frac{1}{\lambda}(\max_e d(e) - \min_e d(e))$. The idea is to first solve for $Z^T = L^{-1} A\, \mathrm{diag}(\exp(-d/2\lambda))\, R$, where $R \in \{-1/\sqrt{k}, +1/\sqrt{k}\}^{|E| \times p}$ is a random projection matrix with elements chosen uniformly from $\{-1/\sqrt{k}, +1/\sqrt{k}\}$ and $p = O(\log n / \varepsilon^2)$. Then, the suggested approximation is $\mu_{i \to j} \approx \|Z(u_i - u_j)\|^2$. While computing $Z$ naïvely would take $O(n^3 + n^2 p)$, one achieves the claimed bound with the Laplacian solver of Spielman and Teng [35].

⁵ For additional properties of the effective resistances see [33].

As an extra benefit, the above connection provides an alternative interpretation of the smoothed FR test. Namely, assume that we want to create a spectral sparsifier [36] of $\tilde{\mathcal{G}}(X)$, which should contain significantly fewer edges, but be a good summary of the graph by having a similar spectrum. Spielman and Srivastava [34] provide a strategy to create such a sparsifier by sampling edges randomly, where edge $e$ is sampled proportionally to $\mu_e$. Hence, by optimizing $T_{\pi^*}^\lambda$ we are encouraging the constructed sparsifier of $\tilde{\mathcal{G}}(X)$ to have in expectation as many edges as possible connecting points from $X_1$ with points from $X_2$.
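The following is a small PyTorch sketch of the exact $O(n^3)$ route described above: a dense weighted Laplacian with one vertex removed, followed by the edge-marginal formula. The helper names and the choice to drop the last vertex are our own, and this is a minimal illustration rather than the authors' implementation.

    import torch
    import torch.nn.functional as F

    def fr_edge_marginals(x, lam):
        """Spanning-tree edge marginals mu_{ij} for the smoothed FR test (dense O(n^3) version)."""
        d = torch.cdist(x, x)                      # pairwise distances d(x_i, x_j)
        w = torch.exp(-d / lam)
        w = w - torch.diag(torch.diag(w))          # edge weights of the tempered graph, no self-loops
        lap = torch.diag(w.sum(dim=1)) - w         # weighted graph Laplacian
        lap_inv = torch.linalg.inv(lap[:-1, :-1])  # remove an arbitrary (here: the last) vertex
        g = F.pad(lap_inv, (0, 1, 0, 1))           # removed vertex corresponds to a zero row/column
        diag_g = torch.diag(g)
        reff = diag_g[:, None] + diag_g[None, :] - 2.0 * g  # (u_i - u_j)^T L^{-1} (u_i - u_j)
        return w * reff                            # symmetric matrix; each pair appears twice

    def smoothed_fr_statistic(x1, x2, lam):
        x = torch.cat([x1, x2], dim=0)
        labels = torch.cat([torch.zeros(len(x1)), torch.ones(len(x2))])
        mu = fr_edge_marginals(x, lam)
        cross = (labels[:, None] != labels[None, :]).float()
        return 0.5 * (mu * cross).sum()            # count each undirected edge once

A quick sanity check of such an implementation is that the marginals of a connected graph sum to $n - 1$, the size of a spanning tree.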

5 Experiments

We implemented our methods in Python using the PyTorch library. For the k-NN test, we have adapted the code accompanying [29]. Throughout this section we used a 10-dimensional normal as $Q_0$, drew samples of equal size $n_1 = n_2$, and used the $\ell_2$ norm $d(x, x') = \|x - x'\|_2$ as a weighting function. We provide additional details in Appendix B.

Power as a function of λ and d. In our first experiment we analyze the effect of the smoothing strength on the power of our differentiable tests. In addition to the classical FR and k-NN tests, we have considered the unbiased MMD test [9] with the squared exponential kernel (as implemented in Shogun [37] using the code from [12]), and the energy test [13]. The problem that we consider, which is challenging in high dimensions, is that of differentiating the distribution $\mathcal{N}(0, I)$ from $\mathcal{N}((\mu, 0, \dots, 0), \mathrm{diag}(\sigma^2, 1, \dots, 1))$. This setting was considered to be fair in [38], as the KL divergence between the distributions is constant irrespective of the dimension. To set the smoothing strength and the bandwidth of the MMD kernel (in addition to the median heuristic), we used the same strategy as in [38] by setting $\lambda = d^\gamma$ for varying $\gamma \in [0, 1]$. The results are presented in Figure 2, where we can observe that (i) our tests have results similar to MMD for shift alternatives, while performing significantly better for scale alternatives, and (ii) by varying the smoothing parameter we can significantly increase the power of the test. In the third column we present only the best performing MMD, while we present the remaining results in Appendix B. Note that we expect the power to go to zero as the dimension increases [7, 38].
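To make the power estimate concrete, the sketch below uses the normal approximation from Section 3 and rejects when the smoothed t-statistic falls below the $\alpha$-quantile of a standard normal. The helpers soft_1nn_marginals and smoothed_t_statistic refer to the earlier sketches, the sampling scheme mirrors the scale alternative above, and all names and default values are our own.

    import torch

    def estimate_power(dim, n=128, mu_shift=0.0, sigma=3.0, lam=1.0, trials=200, alpha=0.05):
        """Monte Carlo estimate of the rejection rate of the smoothed 1-NN t-test at level alpha."""
        z_alpha = torch.distributions.Normal(0.0, 1.0).icdf(torch.tensor(alpha)).item()
        labels = torch.cat([torch.zeros(n), torch.ones(n)])
        rejections = 0
        for _ in range(trials):
            x1 = torch.randn(n, dim)                   # samples from N(0, I)
            x2 = torch.randn(n, dim)
            x2[:, 0] = sigma * x2[:, 0] + mu_shift     # N((mu,0,...,0), diag(sigma^2,1,...,1))
            marg = soft_1nn_marginals(torch.cat([x1, x2]), lam)
            t = smoothed_t_statistic(marg, labels, n, n)
            if t.item() <= z_alpha:                    # few cross edges, i.e. small t, rejects H0
                rejections += 1
        return rejections / trials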
Learning. As we have already hinted in the introduction, we stochastically optimize
$$\max_\theta\; \mathbb{E}_{x_i \sim P,\, z_i \sim Q_0}\Big[ t_{\pi^*}^\lambda\big(\{x_i\}_{i=1}^{n_1}, \{f_\theta(z_i)\}_{i=n_1+1}^{n_1+n_2}\big) \Big]$$
using the Adam [39] optimizer. To optimize, we draw at each round $n_1$ samples from the true distribution $P$ and $n_2 = n_1$ samples from the base measure $Q_0$, and then plug them into the smoothed t-statistic.
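A bare-bones version of this training loop might look as follows. The small stand-in generator, the placeholder hyperparameters, and the assumed helper sample_target are ours (the full architecture is in Appendix B.2); smoothed_t_statistic and soft_1nn_marginals refer to the earlier sketches.

    import torch
    from torch import nn, optim

    noise_dim, ambient_dim, batch, lam = 10, 2, 256, 1.0
    generator = nn.Sequential(nn.Linear(noise_dim, 64), nn.ReLU(), nn.Linear(64, ambient_dim))
    opt = optim.Adam(generator.parameters(), lr=1e-4)
    labels = torch.cat([torch.zeros(batch), torch.ones(batch)])

    for step in range(10000):
        x_real = sample_target(batch)                       # i.i.d. samples from P (assumed provided)
        x_fake = generator(torch.randn(batch, noise_dim))   # f_theta(z) with z ~ Q_0
        marg = soft_1nn_marginals(torch.cat([x_real, x_fake]), lam)
        t = smoothed_t_statistic(marg, labels, batch, batch)
        loss = -t                                           # maximizing the t-statistic fools the test
        opt.zero_grad()
        loss.backward()
        opt.step()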

The first experiment we perform, with the goal of understanding the effects of λ, is on the toy two moons dataset from scikit-learn [40]. We show the results in Figure 3. From the second row, showing the estimated p-value versus the correct one (from 1000 random permutations) at several points during training, we can indeed see that the permutation null gets closer to normality as λ decreases. Most importantly, note that the relationship is monotone, so we would expect the optimization not to be significantly harmed if we use the approximation. Qualitatively, we can observe that the solutions have the general structure of $P$, and that they improve as we decrease λ — the symmetry is better captured and the two moons get better separated.

MNIST. Finally, we have trained several models on the MNIST [41] dataset, which we present in Figure 4. We can observe that, despite the high (784) dimensional data and the fact that we use the distance directly on the pixels, the learned models generate digits that look mostly realistic and are competitive with those obtained using MMD [10, 11].

6 Conclusion

We have developed smoothed two-sample graph tests that can be used for learning implicit models. These tests moreover outperform their classical equivalents on the problem of two-sample testing. We have shown how to compute them by performing inference in undirected models, and presented alternative interpretations by drawing connections to neighbourhood component analysis and spectral graph sparsifiers. In the last section we have experimentally showcased the benefits of our approach, and presented results from a learned model.

Acknowledgements. The research was partially supported by ERC StG 307036 and a Google European PhD Fellowship.

Figure 2: Test power when comparing two normal distributions; each panel shows the test power at α = 0.05 as a function of the dimension. (a) Power against the alternative (µ = 0.5, σ = 1) from n1 = n2 = 128 samples. (b) Power against the alternative (µ = 0, σ = 3) from n1 = n2 = 128 samples. (c) Power against the alternative (µ = 0, σ = 3) from n1 = n2 = 256 samples. In the first two columns we present the 3-NN and FR tests as we vary λ — we use fr-γ for λ = d^γ, and fr-ct for the classical test (analogously for 3-NN). The legends presented in the first row are consistent across the respective columns. The last column compares the best performing of these tests with the best performing MMD tests (the remaining MMD plots are provided in Appendix B). Note that our smoothed tests have the largest power, and they significantly outperform their classical counterparts.

Figure 3: The effect of varying λ on the learned model and the normality of the null statistic. (a) Original data. (b) 1-NN with λ = 10. (c) 1-NN with λ = 1. (d) 1-NN with λ = 0.05. The second row of panels plots the estimated p-value against the true p-value for each setting. Note that with decreasing λ we get closer to normality, and the learned distribution better models the true one.

Figure 4: Four different models trained on MNIST. (a) 1-NN with λ = 10 and n1 = 256. (b) 1-NN with λ = 10 and n1 = 512. (c) FR with λ = 10 and n1 = 128. (d) FR with λ = 5 and n1 = 128.

References

[1] Abraham Wald and Jacob Wolfowitz. On a test whether two samples are from the same population. Annals of Mathematical Statistics, 11(2):147–162, 1940.
[2] Jerome H Friedman and Lawrence C Rafsky. Multivariate generalizations of the Wald-Wolfowitz and Smirnov two-sample tests. Annals of Statistics, pages 697–717, 1979.
[3] Jerome H Friedman and Lawrence C Rafsky. Graph-theoretic measures of multivariate association and prediction. Annals of Statistics, pages 377–391, 1983.
[4] Norbert Henze and Mathew D Penrose. On the multivariate runs test. Annals of Statistics, pages 290–298, 1999.
[5] Norbert Henze. A multivariate two-sample test based on the number of nearest neighbor type coincidences. Annals of Statistics, pages 772–783, 1988.
[6] Mark F Schilling. Multivariate two-sample tests based on nearest neighbors. Journal of the American Statistical Association, 81(395):799–806, 1986.
[7] Bhaswar B Bhattacharya. Power of graph-based two-sample tests. arXiv preprint arXiv:1508.07530, 2015.
[8] Hao Chen and Nancy R Zhang. Graph-based tests for two-sample comparisons of categorical data. Statistica Sinica, pages 1479–1503, 2013.
[9] Arthur Gretton, Karsten M Borgwardt, Malte J Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773, 2012.
[10] Yujia Li, Kevin Swersky, and Rich Zemel. Generative moment matching networks. In International Conference on Machine Learning (ICML), 2015.
[11] Gintare Karolina Dziugaite, Daniel M. Roy, and Zoubin Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In Uncertainty in Artificial Intelligence (UAI), 2015.
[12] Dougal J Sutherland, Hsiao-Yu Tung, Heiko Strathmann, Soumyajit De, Aaditya Ramdas, Alex Smola, and Arthur Gretton. Generative models and model criticism via optimized maximum mean discrepancy. In International Conference on Learning Representations (ICLR), 2016.
[13] Gábor J Székely and Maria L Rizzo. Energy statistics: A class of statistics based on distances. Journal of Statistical Planning and Inference, 143(8):1249–1272, 2013.
[14] Marc G Bellemare, Ivo Danihelka, Will Dabney, Shakir Mohamed, Balaji Lakshminarayanan, Stephan Hoyer, and Rémi Munos. The Cramér distance as a solution to biased Wasserstein gradients. arXiv preprint arXiv:1705.10743, 2017.
[15] Masashi Sugiyama, Taiji Suzuki, and Takafumi Kanamori. Density ratio estimation in machine learning. Cambridge University Press, 2012.
[16] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems (NIPS), pages 2672–2680, 2014.
[17] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training GANs. In Advances in Neural Information Processing Systems (NIPS), pages 2234–2242, 2016.
[18] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In Advances in Neural Information Processing Systems (NIPS), pages 271–279, 2016.
[19] Syed Mumtaz Ali and Samuel D Silvey. A general class of coefficients of divergence of one distribution from another. Journal of the Royal Statistical Society, Series B (Methodological), pages 131–142, 1966.
[20] Shakir Mohamed and Balaji Lakshminarayanan. Learning in implicit generative models. arXiv preprint arXiv:1610.03483, 2016.
[21] Robert Clay Prim. Shortest connection networks and some generalizations. Bell Labs Technical Journal, 36(6):1389–1401, 1957.
[22] Joseph B Kruskal. On the shortest spanning subtree of a graph and the traveling salesman problem. Proceedings of the American Mathematical Society, 7(1):48–50, 1956.
[23] Visar Berisha and Alfred O Hero. Empirical non-parametric estimation of the Fisher information. IEEE Signal Processing Letters, 22(7):988–992, 2015.
[24] Martin J Wainwright and Michael I Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2), 2008.
[25] Henry E Daniels. The relation between measures of correlation in the universe of sample permutations. Biometrika, 33(2):129–135, 1944.
[26] AD Barbour and GK Eagleson. Random association of symmetric arrays. Stochastic Analysis and Applications, 4(3):239–281, 1986.
[27] Joseph E Yukich. Probability theory of classical Euclidean optimization problems. Springer, 2006.
[28] Daniel Tarlow, Kevin Swersky, Richard S Zemel, Ryan Prescott Adams, and Brendan J Frey. Fast exact inference for recursive cardinality models. In Uncertainty in Artificial Intelligence (UAI), 2012.
[29] Kevin Swersky, Ilya Sutskever, Daniel Tarlow, Richard S Zemel, Ruslan R Salakhutdinov, and Ryan P Adams. Cardinality restricted Boltzmann machines. In Advances in Neural Information Processing Systems (NIPS), pages 3293–3301, 2012.
[30] Jacob Goldberger, Geoffrey E Hinton, Sam T Roweis, and Ruslan R Salakhutdinov. Neighbourhood components analysis. In Advances in Neural Information Processing Systems (NIPS), pages 513–520, 2005.
[31] Daniel Tarlow, Kevin Swersky, Laurent Charlin, Ilya Sutskever, and Rich Zemel. Stochastic k-neighborhood selection for supervised and unsupervised learning. In International Conference on Machine Learning, pages 199–207, 2013.
[32] Russell Lyons. Determinantal probability measures. Publications Mathématiques de l'IHÉS, 98(1):167–212, 2003.
[33] Ashok K Chandra, Prabhakar Raghavan, Walter L Ruzzo, Roman Smolensky, and Prasoon Tiwari. The electrical resistance of a graph captures its commute and cover times. Computational Complexity, 6(4):312–340, 1996.
[34] Daniel A Spielman and Nikhil Srivastava. Graph sparsification by effective resistances. SIAM Journal on Computing, 40(6):1913–1926, 2011.
[35] Daniel A Spielman and Shang-Hua Teng. Nearly linear time algorithms for preconditioning and solving symmetric, diagonally dominant linear systems. SIAM Journal on Matrix Analysis and Applications, 35(3):835–885, 2014.
[36] Daniel A Spielman and Shang-Hua Teng. Spectral sparsification of graphs. SIAM Journal on Computing, 40(4):981–1025, 2011.
[37] Sören Sonnenburg, Sebastian Henschel, Christian Widmer, Jonas Behr, Alexander Zien, Fabio De Bona, Alexander Binder, Christian Gehl, Vojtěch Franc, et al. The SHOGUN machine learning toolbox. Journal of Machine Learning Research, 11(Jun):1799–1802, 2010.
[38] Aaditya Ramdas, Sashank Jakkam Reddi, Barnabás Póczos, Aarti Singh, and Larry A Wasserman. On the decreasing power of kernel and distance based nonparametric hypothesis tests in high dimensions. In AAAI, 2015.
[39] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
[40] F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.
[41] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

A Proofs

Proof of Theorem 3. The expectation of the statistic under $H_0$ is (when $\pi$ is a uniformly random labelling)
$$\mathbb{E}_{\pi \sim H_0}[T_{\pi^*}^\lambda] = \sum_{e \in E} \mu(d/\lambda)_e\, \underbrace{\mathbb{E}_\pi[\Delta_\pi(e)]}_{2 n_1 n_2 / n(n-1)} = \frac{2 m n_1 n_2}{n(n-1)},$$
where the inner expectation $\mathbb{E}_\pi[\Delta_\pi(e)]$ has been computed in [2]. We can also easily compute the variance as
$$\sum_{e, e' \in E} \mathrm{Cov}_{\pi \sim H_0}\big[\mu_e \Delta_\pi(e),\, \mu_{e'} \Delta_\pi(e')\big] = \sum_{e, e' \in E} \mu_e \mu_{e'}\, \underbrace{\mathbb{E}_{\pi \sim H_0}[\Delta_\pi(e)\, \Delta_\pi(e')]}_{\Pi_{e,e'}} \;-\; \underbrace{\frac{4 n_1^2 n_2^2}{n^2 (n-1)^2}\, m^2}_{(\mathbb{E}_{\pi \sim H_0}[T_{\pi^*}^\lambda])^2}. \tag{5}$$

Proof of Lemma 2. We can split the sum in the variance formula over all edge pairs into three groups as follows
$$\underbrace{\frac{n_1 n_2}{n(n-1)}}_{\chi_1} \sum_e \sum_{e' \sim e} \mu_e \mu_{e'} \;+\; \underbrace{\frac{4 n_1 n_2 (n_1-1)(n_2-1)}{n(n-1)(n-2)(n-3)}}_{\chi_1 \chi_2} \sum_e \sum_{e' \perp e} \mu_e \mu_{e'} \;+\; \underbrace{\frac{n_1 n_2}{n(n-1)}}_{\chi_1} \sum_e (\mu_e^2 + \mu_e \mu_{\bar{e}}), \tag{6}$$
where $\sum_{e' \sim e}$ sums over all edges $e'$ that share at least one vertex with $e$, $\sum_{e' \perp e}$ sums over those edges that share no vertex with $e$, and $\bar{e}$ denotes the reverse edge of $e$ (if it exists, zero otherwise). Note that each term $\mu_e \mu_{e'}$ appears twice if $e \neq e'$, as in the formula for the variance (5). Moreover, note that if $\delta(e) = \delta(e')$, then in the above formula the term $\mu_e \mu_{e'}$ (same for $\mu_{e'} \mu_e$) gets multiplied by $2\chi_1 = \Pi_{e,e'}$, as it appears in both the first and the third term. Given the assumption that $|U| = m$ under $\nu(\cdot)$, we also know that
$$m^2 = \Big(\sum_e \mu_e\Big)^2 = \sum_e \sum_{e'} \mu_e \mu_{e'} = \sum_e \sum_{e' \sim e} \mu_e \mu_{e'} + \sum_e \sum_{e' \perp e} \mu_e \mu_{e'},$$
so that eq. (6) can be simplified to
$$\chi_1 \sum_e \sum_{e' \sim e} \mu_e \mu_{e'} + \chi_1 \chi_2 \Big(m^2 - \sum_e \sum_{e' \sim e} \mu_e \mu_{e'}\Big) + \chi_1 \sum_e (\mu_e^2 + \mu_e \mu_{\bar{e}}),$$
which can be simplified to
$$\chi_1 (1 - \chi_2) \sum_e \sum_{e' \sim e} \mu_e \mu_{e'} + \chi_1 \sum_e (\mu_e^2 + \mu_e \mu_{\bar{e}}) + \chi_1 \chi_2\, m^2.$$
Now the result follows by observing that
$$\sum_v \Big(\sum_{e \in \delta(v)} \mu_e\Big)^2 = \sum_e \sum_{e' \sim e} \mu_e \mu_{e'} + \sum_e \mu_e^2 + \sum_e \mu_e \mu_{\bar{e}}.$$
To understand why this holds, let us count how many times each term $\mu_e \mu_{e'}$ appears on both sides of the equality if we expand the lhs. If $e \neq e'$ and they share exactly one vertex, then the lhs will have two $\mu_e \mu_{e'}$ terms, as $\mu_e$ and $\mu_{e'}$ will be multiplied only at the term corresponding to the shared vertex. On the other hand, if $e = e'$ we will again have two $\mu_e \mu_{e'} = \mu_e^2$ terms, as we get one contribution from each end-point of $e$. Finally, if $e' = \bar{e}$, we have a total of four $\mu_e \mu_{e'}$ terms, as we get two $\mu_e \mu_{e'}$ from each end-point. Thus, eq. (6) is equal to
$$\chi_1 (1 - \chi_2) \Big( \sum_v \Big(\sum_{e \in \delta(v)} \mu_e\Big)^2 - \sum_e \mu_e^2 - \sum_e \mu_e \mu_{\bar{e}} \Big) + \chi_1 \sum_e (\mu_e^2 + \mu_e \mu_{\bar{e}}) + \chi_1 \chi_2\, m^2.$$
Finally, if we subtract $4\chi_1^2 m^2$ and simplify the expression we have
$$\chi_1 (1 - \chi_2) \sum_v \Big(\sum_{e \in \delta(v)} \mu_e\Big)^2 + \chi_1 \chi_2 \sum_e \mu_e^2 + \chi_1 \chi_2 \sum_e \mu_e \mu_{\bar{e}} + \chi_1 (\chi_2 - 4\chi_1)\, m^2,$$
which is exactly what is claimed in the theorem, if we observe that $e$ and $\bar{e}$ are the only edges parallel to $e$.

Proof that $\chi_2 - 4\chi_1 \geq 0$ when $n_1 = n_2 = n/2$. First, note that $\frac{n_1}{n} \leq \frac{n_1 - 1}{n - 2}$ if and only if $n_1 n - 2 n_1 \leq n n_1 - n$, which is equivalent to $n_1 \geq \frac{1}{2} n$. Similarly, we have $\frac{n_2}{n-1} \leq \frac{n_2 - 1}{n - 3}$ iff $n n_2 - 3 n_2 \leq n n_2 - n - n_2 + 1$, which can be re-written as $-2 n_2 \leq -n + 1$, i.e., $n_2 \geq \frac{n}{2} - \frac{1}{2}$. Combining these two inequalities proves the result.

Proof of Theorem 4. Let us compute an upper bound on the quantities in [26]:
$$a_1 = \frac{1}{n(n-1)} \sum_{i,j} \mu_{i,j} = \frac{k}{n}, \qquad b_1 = \frac{2 n_2 n_1}{n(n-1)} = \Theta(1),$$
$$a_2 = \frac{1}{n(n-1)(n-2)} \underbrace{\sum_{i,j,k} \mu_{i,j}\, \mu_{i,k}}_{S_2}, \qquad b_2 = \frac{n_2 n_1^2 + n_1 n_2^2}{n(n-1)(n-2)} = \Theta(1),$$
$$a_3 = \frac{1}{n(n-1)(n-2)(n-3)} \underbrace{\sum_{i,j,k,m} \mu_{i,j}\, \mu_{i,k}\, \mu_{i,m}}_{S_3}, \qquad b_3 = \frac{n_2 n_1^3 + n_1 n_2^3}{n(n-1)(n-2)(n-3)} = \Theta(1),$$
$$a_4 = \frac{1}{n(n-1)(n-2)(n-3)} \underbrace{\sum_{i,j,k,m} \mu_{k,i}\, \mu_{i,j}\, \mu_{j,m}}_{L_4}, \qquad b_4 = \frac{2 n_2^2 n_1^2}{n(n-1)(n-2)(n-3)} = \Theta(1),$$
$$a_5 = \frac{1}{n(n-1)(n-2)} \sum_{i,j,k} \mu_{i,j}^2\, \mu_{i,k} = O(a_2), \qquad b_5 = b_2,$$
$$a_6 = \frac{1}{n(n-1)} \sum_{i,j} \mu_{i,j}^3 = O(a_1), \qquad b_6 = b_1,$$
$$a_7 = \frac{1}{n(n-1)(n-2)} \sum_{i,j,k} \mu_{i,j}\, \mu_{i,k}\, \mu_{j,k}, \qquad b_7 = \frac{n_2 n_1 n_2 + n_1 n_2 n_1}{n(n-1)(n-2)} = \Theta(1),$$
$$a_8 = \frac{1}{n(n-1)} \sum_{i,j} \mu_{i,j}^2 = O(a_1), \qquad b_8 = b_1.$$
Then, the upper bound has the form
$$\frac{1}{\sigma^3} \Big[ n^4 \big( \underbrace{a_1^3}_{k^3/n^3} + \underbrace{a_1 a_2}_{O(k S_2/n^4)} + \underbrace{a_3}_{O(S_3/n^4)} + \underbrace{a_4}_{O(L_4/n^4)} \big) \underbrace{\big( b_1^3 + b_1 b_2 + b_3 + b_4 \big)}_{O(1)} + n^3 \big( \underbrace{a_5}_{O(S_2/n^3)} + \underbrace{a_1 a_8}_{O(k^2/n^2)} \big) \underbrace{\big( b_5 + b_1 b_8 \big)}_{O(1)} + n^2\, \underbrace{a_6}_{O(k/n)}\, \underbrace{b_6}_{O(1)} \Big],$$
which can be simplified to
$$O\Big( \frac{1}{\sigma^3} \big( n k^3 + k S_2 + S_3 + L_4 + S_2 + n k^2 + n k \big) \Big) = O\Big( \frac{1}{\sigma^3} \big( n k^3 + k S_2 + S_3 + L_4 \big) \Big),$$
which is what is claimed in the theorem.

B Experiments

B.1

Figure 5: The different MMD tests (mmd-0.0 through mmd-1.0 and mmd-median) on the three setups in Figure 2: (a) µ = 0.5, σ = 1, n1 = 128; (b) µ = 0, σ = 3, n1 = 128; (c) µ = 0, σ = 3, n1 = 256. Each panel shows the test power at α = 0.05 as a function of the dimension. The legend is consistent across the panels.

B.2 Architecture

We have used the same architecture as in [10, 12], which, using the modules from PyTorch, can be written as follows.

    import torch.nn as nn

    nn.Sequential(
        nn.Linear(noise_dim, 64), nn.ReLU(),
        nn.Linear(64, 256), nn.ReLU(),
        nn.Linear(256, 256), nn.ReLU(),
        nn.Linear(256, 1024), nn.ReLU(),
        nn.Linear(1024, ambient_dim),
    )

For MNIST we have also added a terminal nn.Tanh layer.

B.3 Data

We have used the MNIST data as packaged by torchvision, with the additional processing of scaling the output to [−1, 1] as we are using a final Tanh layer. For the two moons data, we have used a noise level of 0.05.
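For completeness, the corresponding data loading might look as follows; the transform details beyond the [−1, 1] scaling and the sample count for the two moons are our own guesses.

    import torch
    from torchvision import datasets, transforms
    from sklearn.datasets import make_moons

    # MNIST, scaled to [-1, 1] to match the final Tanh layer
    mnist = datasets.MNIST(
        root="./data", train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),                       # pixels in [0, 1]
            transforms.Lambda(lambda t: 2.0 * t - 1.0),  # rescale to [-1, 1]
        ]),
    )

    # two moons with the noise level stated above
    x, _ = make_moons(n_samples=1024, noise=0.05)
    x = torch.as_tensor(x, dtype=torch.float32)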

B.4 Optimization

All details are provided in the table below. In some cases we have optimized with a larger step size for a number of epochs, and then reduced it for the remaining epochs — in the table below these are separated by commas.

    Model       Step size        Batch size   Epochs
    Figure 3b   10^-4            256          500
    Figure 3c   10^-4            256          500
    Figure 3d   10^-4            256          500
    Figure 4a   10^-3, 10^-4     256          500, 500
    Figure 4b   10^-3, 10^-4     512          500, 500
    Figure 4c   10^-3, 10^-4     128          100, 100
    Figure 4d   10^-4, 10^-4     128          100, 100
