In this blog post, we explore how Wasserstein Generative Adversarial Nets (WGANs) improve upon the minimax objective of Generative Adversarial Nets (GANs) to stabilize training and make the value of the game correlate better with the performance of the generator. We first derive the divergence between the real data distribution and the generated one that GANs minimize. Then, we discuss how this divergence is sub-optimal for the optimization of neural networks and introduce the Wasserstein distance, proving that it has better properties with respect to neural net optimization. Thereafter, we show that the Wasserstein distance, although intractable in its original form, can be approximated, and that its gradient can indeed be back-propagated to the generator.
We assume the reader is familiar with basic principles of machine learning, like neural networks and gradient descent, with simple probabilistic concepts like probability density functions, with the basic formulation of GANs, and with Lipschitz functions.
GANs minimize the Jensen-Shannon divergence
Classical approaches to learning a probability distribution assume a parametric family of probability distributions $(P_\theta)_{\theta \in \mathbb{R}^d}$ and search for the parameters that maximize the likelihood of the observed samples $\{x^{(i)}\}_{i=1}^m$,

$$\max_{\theta \in \mathbb{R}^d} \; \frac{1}{m} \sum_{i=1}^m \log P_\theta\!\left(x^{(i)}\right),$$

which can easily be proven to be equal (as the number of samples grows) to minimizing the Kullback-Leibler (KL) divergence

$$KL(P_r \,\|\, P_\theta) = \int p_r(x) \log \frac{p_r(x)}{p_\theta(x)} \, dx,$$

where $P_r$ is the real data distribution.
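To see why the two problems have the same solution, here is a short derivation sketch (assuming, as usual, that the empirical average of the log-likelihood converges to its expectation under $P_r$ and that both distributions admit densities):

$$\begin{aligned}
\arg\max_{\theta} \; \mathbb{E}_{x \sim P_r}\left[ \log p_\theta(x) \right]
  &= \arg\max_{\theta} \; \mathbb{E}_{x \sim P_r}\left[ \log \frac{p_\theta(x)}{p_r(x)} \right]
     && (\log p_r(x) \text{ does not depend on } \theta) \\
  &= \arg\max_{\theta} \; \left( -\, KL(P_r \,\|\, P_\theta) \right) \\
  &= \arg\min_{\theta} \; KL(P_r \,\|\, P_\theta).
\end{aligned}$$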
This equivalence gives us the opportunity to discuss maximizing the likelihood in terms of divergences between distributions. When learning using the KL divergence, we have to keep four things in mind. First, as can be seen from its definition, the KL divergence heavily punishes "mode dropping", i.e. the model assigning little to no probability to areas where the real distribution has "mass". As a matter of fact, if the model assigns zero mass where the real distribution is non-zero, the divergence becomes infinite. Second, assigning "mass" to fake samples is not heavily punished, since wherever $p_r(x) \approx 0$ the integrand is close to $0$ regardless of how much mass the model places there. That is important in light of the fact that the KL divergence is non-negative (we prove that in [1]). Third, notice that the KL divergence is not symmetric: $KL(P_r \,\|\, P_\theta) \neq KL(P_\theta \,\|\, P_r)$ in general. Finally, the KL divergence has the nice property of having a single minimum, attained where the two distributions coincide.
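As a quick sanity check of the first three points, the following sketch (with made-up toy discrete distributions; NumPy is assumed) compares both directions of the KL divergence for a model that drops one of the real modes and for a model that leaks mass onto a fake mode:

```python
import numpy as np

def kl(p, q, eps=1e-12):
    """Discrete KL divergence KL(p || q), with a small epsilon for numerical safety."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return np.sum(p * np.log((p + eps) / (q + eps)))

# "Real" distribution: two modes of equal mass, nothing on the third outcome.
p_real = np.array([0.5, 0.5, 0.0])

q_drop = np.array([1.0, 0.0, 0.0])    # model that drops the second real mode
q_fake = np.array([0.45, 0.45, 0.1])  # model that keeps both modes but adds a fake one

print(kl(p_real, q_drop))  # very large: mode dropping is punished heavily
print(kl(p_real, q_fake))  # small: placing some mass on fake samples is punished mildly
print(kl(q_fake, p_real))  # differs from the previous line: KL is not symmetric
```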
So, how do GANs handle learning a probability distribution? The intuitive formulation, presented in the original paper of Goodfellow et al. (2014), is that we use a neural net that allows sampling from the distribution based on a helper random variable and train it in a minimax game with another helper neural net, the discriminator. This can be formulated as:

$$\min_G \max_D \; V(G, D) = \mathbb{E}_{x \sim P_r}\left[ \log D(x) \right] + \mathbb{E}_{z \sim p(z)}\left[ \log\left(1 - D(G(z))\right) \right],$$

where $G$ is the generator and $D$ the discriminator. So, is that all, just alternating training steps between two neural nets? No. As Goodfellow et al. (2014) showed, the optimal discriminator for a fixed generator is

$$D^*_G(x) = \frac{p_r(x)}{p_r(x) + p_g(x)},$$

which can easily be proven using the objective of the minimax game: writing it as

$$V(G, D) = \int_x \left[ p_r(x) \log D(x) + p_g(x) \log\left(1 - D(x)\right) \right] dx$$

and noting that, for fixed $a, b \geq 0$, the function $y \mapsto a \log y + b \log(1 - y)$ attains its maximum on $[0, 1]$ at $y = \frac{a}{a + b}$. Plugging $D^*_G$ back into the objective yields

$$V(G, D^*_G) = 2\, JS(P_r \,\|\, P_g) - \log 4,$$
where JS is the Jensen-Shannon divergence,

$$JS(P_r \,\|\, P_g) = \frac{1}{2} KL\!\left(P_r \,\Big\|\, \frac{P_r + P_g}{2}\right) + \frac{1}{2} KL\!\left(P_g \,\Big\|\, \frac{P_r + P_g}{2}\right).$$

Therefore, we can see that we learn the generator by minimizing the JS divergence between the real distribution and the generated one (approximately, depending on how close the discriminator is to optimality). In contrast to the KL divergence, the JS divergence is symmetric. Upon closer inspection, we see that it also serves as a middle ground between punishing fake samples and punishing mode dropping, because it averages two KL divergences against the mixture $\frac{P_r + P_g}{2}$, with each of the two distributions serving as the first argument.
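We can verify the relation between the optimal-discriminator objective and the JS divergence numerically; a small sketch with made-up toy discrete distributions:

```python
import numpy as np

def kl(p, q):
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Toy "real" and "generated" distributions over four outcomes.
p_r = np.array([0.4, 0.3, 0.2, 0.1])
p_g = np.array([0.1, 0.2, 0.3, 0.4])

# Optimal discriminator D*(x) = p_r(x) / (p_r(x) + p_g(x)).
d_star = p_r / (p_r + p_g)

# GAN objective evaluated at the optimal discriminator...
value_at_d_star = np.sum(p_r * np.log(d_star)) + np.sum(p_g * np.log(1.0 - d_star))
# ...equals 2 * JS(P_r || P_g) - log 4.
print(value_at_d_star, 2 * js(p_r, p_g) - np.log(4))
```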
It is also important to note that in the inaugural GAN paper [2], Goodfellow et al. already propose updating the discriminator multiple times before each generator update, which means that the optimality assumption for the JS divergence is not far from the truth, especially as training proceeds.
Wasserstein distance
We saw that GAN training can be thought of as minimizing the JS divergence between the real distribution and the distribution of the generator. Questions naturally arise: what other divergences can we minimize? What are the advantages and the disadvantages of using a particular divergence? To begin with, the most fundamental difference between divergences is their impact on the convergence of sequences of probability distributions (in our case, the sequences induced by updating the parameters of the generator). A sequence of distributions $(P_t)_{t \in \mathbb{N}}$ converges under a divergence $\rho$ iff there exists a distribution $P_\infty$ such that $\rho(P_t, P_\infty)$ tends to zero, and whether this happens depends on the choice of $\rho$: the weaker the divergence, the easier it is for sequences to converge.
Essentially, we would like to have a continuous mapping $\theta \mapsto P_\theta$ with respect to the divergence we minimize, so that small changes in the generator's parameters lead to small changes in the loss and gradient descent receives a useful signal.
As it happens, the JS divergence is not optimal for such a setting. It is not continuous when the supports of the two distributions are disjoint or lie on low-dimensional manifolds, which is the typical situation when a generator maps a low-dimensional noise variable into a high-dimensional sample space. A divergence with much better continuity properties is the Wasserstein-1 (or Earth Mover's) distance,

$$W(P_r, P_\theta) = \inf_{\gamma \in \Pi(P_r, P_\theta)} \mathbb{E}_{(x, y) \sim \gamma}\left[\, \| x - y \| \,\right],$$

where $\Pi(P_r, P_\theta)$ is the set of all joint distributions (couplings) $\gamma(x, y)$ whose marginals are $P_r$ and $P_\theta$, respectively. Intuitively, it measures the minimum cost of transporting the "mass" of one distribution onto the other.
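For one-dimensional distributions the Wasserstein-1 distance has a simple closed form and is available in SciPy, so we can get a feel for how smoothly it reacts to changes in the generated distribution (the Gaussian samples below are made up for illustration):

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)

# Samples from a "real" distribution and from two candidate models.
real = rng.normal(loc=0.0, scale=1.0, size=10_000)
close = rng.normal(loc=0.1, scale=1.0, size=10_000)  # slightly shifted model
far = rng.normal(loc=3.0, scale=1.0, size=10_000)    # model that is far away

print(wasserstein_distance(real, close))  # ~0.1: small shift, small distance
print(wasserstein_distance(real, far))    # ~3.0: the distance grows smoothly with the shift
```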
If the function $g_\theta$ implemented by the generator is locally Lipschitz and satisfies the regularity assumption of [3], then $W(P_r, P_\theta)$ is continuous everywhere and differentiable almost everywhere.
It is relatively easy to show that a neural network, and therefore a generative network, satisfies the regularity assumption. However, because the formulas that are required become complicated (it basically suffices to show that the local Lipschitz constants $L(\theta, z)$ have finite expectation over the noise variable, $\mathbb{E}_{z}\left[ L(\theta, z) \right] < +\infty$), we refer the interested reader to [3] for the details.
Now, we prove these claims. First, the continuity:
If the generator $g_\theta(z)$ is continuous in $\theta$, then so is the mapping $\theta \mapsto W(P_r, P_\theta)$.
Let $\theta$ be a parameter vector and $\theta_t \to \theta$ a sequence converging to it. The joint distribution of the pair $(g_{\theta}(z), g_{\theta_t}(z))$, with $z$ drawn from the noise distribution, is a valid coupling of $P_\theta$ and $P_{\theta_t}$, so by the definition of the Wasserstein distance as an infimum over couplings,

$$W(P_\theta, P_{\theta_t}) \leq \mathbb{E}_{z}\left[\, \| g_\theta(z) - g_{\theta_t}(z) \| \,\right],$$

and because the modeled variable's space is compact, $\| g_\theta(z) - g_{\theta_t}(z) \|$ has to be uniformly bounded by some constant, which, along with the fact that $g$ is continuous in $\theta$ (so the integrand converges pointwise to $0$), lets us invoke the bounded convergence theorem and conclude that $W(P_\theta, P_{\theta_t}) \to 0$. In turn, because the Wasserstein distance is a distance, the triangle inequality gives:

$$\left| W(P_r, P_\theta) - W(P_r, P_{\theta_t}) \right| \leq W(P_\theta, P_{\theta_t}) \longrightarrow 0,$$

which proves the continuity of $\theta \mapsto W(P_r, P_\theta)$.
For the differentiability claim, assume additionally that $g$ is locally Lipschitz: for every pair $(\theta, z)$ there is a constant $L(\theta, z)$ and a neighbourhood in which

$$\| g_\theta(z) - g_{\theta'}(z') \| \leq L(\theta, z) \left( \| \theta - \theta' \| + \| z - z' \| \right).$$

By taking expectations and setting $z' = z$, the coupling bound above gives

$$W(P_\theta, P_{\theta'}) \leq \mathbb{E}_{z}\left[\, \| g_\theta(z) - g_{\theta'}(z) \| \,\right] \leq \mathbb{E}_{z}\left[ L(\theta, z) \right] \, \| \theta - \theta' \|.$$

If we define $L(\theta) = \mathbb{E}_{z}\left[ L(\theta, z) \right]$, which is finite by the regularity assumption, the triangle inequality again yields

$$\left| W(P_r, P_\theta) - W(P_r, P_{\theta'}) \right| \leq L(\theta) \, \| \theta - \theta' \|$$

for all $\theta'$ sufficiently close to $\theta$ (see [3] for the careful treatment of the neighbourhoods involved). Hence $W(P_r, P_\theta)$ is locally Lipschitz in $\theta$ and, by Rademacher's theorem, differentiable almost everywhere.
To show that these properties do not hold for the JS divergence, we use a counterexample. Let $Z \sim U[0, 1]$, let $P_0$ be the distribution of the point $(0, Z) \in \mathbb{R}^2$ (a vertical segment at $x = 0$), and let the "generator" be $g_\theta(z) = (\theta, z)$, so that $P_\theta$ is the same segment shifted to $x = \theta$. Then $W(P_0, P_\theta) = |\theta|$, which is continuous in $\theta$, whereas $JS(P_0, P_\theta) = \log 2$ for every $\theta \neq 0$ and $0$ for $\theta = 0$ (and the KL divergence is $+\infty$ for $\theta \neq 0$), so the JS divergence is discontinuous at $0$ and provides no useful gradient anywhere else.
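A discrete analogue of this counterexample is easy to play with numerically. The sketch below approximates each segment by a point mass at its $x$-coordinate (the $y$-coordinate affects neither quantity here) and shows $W$ shrinking smoothly to $0$ as $\theta \to 0$ while the JS divergence stays stuck at $\log 2$:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def js_discrete(p, q, eps=1e-12):
    """Jensen-Shannon divergence between two discrete distributions on the same support."""
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(a * np.log((a + eps) / (b + eps)))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

for theta in [1.0, 0.5, 0.1, 0.01]:
    support = np.array([0.0, theta])  # the two possible x-coordinates
    p0 = np.array([1.0, 0.0])         # all "real" mass at x = 0
    p_theta = np.array([0.0, 1.0])    # all "generated" mass at x = theta
    w = wasserstein_distance(support, support, p0, p_theta)  # equals |theta|
    print(f"theta={theta:5.2f}  W={w:.3f}  JS={js_discrete(p0, p_theta):.3f}  (log 2 = {np.log(2):.3f})")
```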
This theorem basically shows us that the Wasserstein distance is a more sensible choice of divergence than the JS one for training a generator. However, as one can guess, the infimum over couplings is highly intractable (see [4]). Thankfully, in the same manner that the JS divergence is approximately the objective when the discriminator is optimal for a given generator, we can arrive at a tractable approximation of the Wasserstein distance.
Wasserstein GAN
In [4], we can find the description of the Kantorovich-Rubinstein formula (duality), i.e.

$$W(P_r, P_\theta) = \sup_{\| f \|_L \leq 1} \; \mathbb{E}_{x \sim P_r}\left[ f(x) \right] - \mathbb{E}_{x \sim P_\theta}\left[ f(x) \right],$$

where the supremum is taken over all 1-Lipschitz functions $f$ (taking it over $K$-Lipschitz functions instead yields $K \cdot W(P_r, P_\theta)$, which has the same minimizers with respect to $\theta$).
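To build some intuition for this duality, the following sketch (made-up discrete distributions; SciPy's linear-programming routine for the dual and its 1-D primal solver as the reference) maximizes $\sum_i f_i (p_i - q_i)$ over functions $f$ that are 1-Lipschitz on a grid and checks that the result matches the primal Wasserstein distance:

```python
import numpy as np
from scipy.optimize import linprog
from scipy.stats import wasserstein_distance

# Two discrete distributions supported on the same sorted 1-D grid.
grid = np.linspace(0.0, 1.0, 6)
p = np.array([0.3, 0.3, 0.2, 0.1, 0.1, 0.0])
q = np.array([0.0, 0.1, 0.1, 0.2, 0.3, 0.3])

# Dual problem: maximize sum_i f_i * (p_i - q_i) over 1-Lipschitz f.
# On a sorted grid it suffices to constrain consecutive points,
# |f_{i+1} - f_i| <= x_{i+1} - x_i; the general case follows by the triangle inequality.
n = len(grid)
c = -(p - q)  # linprog minimizes, so we negate the objective
A_ub, b_ub = [], []
for i in range(n - 1):
    row = np.zeros(n)
    row[i + 1], row[i] = 1.0, -1.0
    step = grid[i + 1] - grid[i]
    A_ub.append(row);  b_ub.append(step)   #   f_{i+1} - f_i <= step
    A_ub.append(-row); b_ub.append(step)   # -(f_{i+1} - f_i) <= step

# Fix f_0 = 0: the dual objective is unchanged by adding a constant to f.
bounds = [(0.0, 0.0)] + [(None, None)] * (n - 1)
res = linprog(c, A_ub=np.array(A_ub), b_ub=np.array(b_ub), bounds=bounds, method="highs")

print(-res.fun)                                # dual optimum
print(wasserstein_distance(grid, grid, p, q))  # primal Wasserstein-1 distance; the two agree
```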
Notice that the Kantorovich-Rubinstein expression, in particular its right-hand side, when viewed as a function of $\theta$, can be differentiated (under the conditions discussed earlier) as

$$\nabla_\theta W(P_r, P_\theta) = -\, \mathbb{E}_{z}\left[ \nabla_\theta f^*\!\left( g_\theta(z) \right) \right],$$

where $f^*$ is the function attaining the supremum. In practice, we do not have access to $f^*$, so we approximate it with another neural network, the critic $f_w$ (the WGAN counterpart of the discriminator), which we train to solve

$$\max_{w \in \mathcal{W}} \; \mathbb{E}_{x \sim P_r}\left[ f_w(x) \right] - \mathbb{E}_{z}\left[ f_w\!\left( g_\theta(z) \right) \right],$$

where $\mathcal{W}$ is a compact parameter space, enforced in [3] by clipping the weights to a fixed box $[-c, c]$ after every update, which guarantees that all $f_w$ are $K$-Lipschitz for some $K$ depending only on $c$ and the architecture. The generator is then updated with the gradient $-\mathbb{E}_{z}\left[ \nabla_\theta f_w\!\left( g_\theta(z) \right) \right]$, with the trained critic standing in for $f^*$.
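Putting everything together, here is a minimal, self-contained sketch of the resulting training loop in PyTorch (the toy 2-D data and tiny MLPs are made up for illustration; the learning rate, clip value, and critic-to-generator update ratio follow the values suggested in [3], but nothing here is a tuned implementation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, n_critic, clip_value = 8, 5, 0.01  # clip value and update ratio as suggested in [3]

# Toy setup: the "real" data are 2-D Gaussian samples; both networks are small MLPs.
real_data = torch.randn(2048, 2) * 0.5 + torch.tensor([2.0, -1.0])
generator = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, 2))
critic = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))

opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

for step in range(1000):
    real = real_data[torch.randint(0, len(real_data), (64,))]

    # --- Critic update: maximize E[f_w(x)] - E[f_w(g_theta(z))] ---
    fake = generator(torch.randn(64, z_dim)).detach()  # no generator gradients here
    loss_c = -(critic(real).mean() - critic(fake).mean())
    opt_c.zero_grad(); loss_c.backward(); opt_c.step()

    # Weight clipping keeps the critic's parameters in a compact set,
    # so every f_w is K-Lipschitz for some fixed K (the trick used in [3]).
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip_value, clip_value)

    # --- Generator update every n_critic steps: minimize -E[f_w(g_theta(z))] ---
    if step % n_critic == 0:
        loss_g = -critic(generator(torch.randn(64, z_dim))).mean()
        opt_g.zero_grad(); loss_g.backward(); opt_g.step()

    if step % 200 == 0:
        # -loss_c approximates (a scaled version of) the Wasserstein distance.
        print(f"step {step:4d}  critic estimate of W: {-loss_c.item():.4f}")
```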
Brief empirical comparison
GANs are notoriously difficult to train. Experiments utilizing the Wasserstein distance show improved stability in the training of GANs. Moreover, due to the properties of the Wasserstein distance discussed above, the objective of the minimax game that arises in the WGAN framework correlates better with the quality of the generated samples than the corresponding value of the original GAN does. Some examples, reproduced from [3], are:
[Figure: Objective of the minimax game in a GAN (JS divergence).]

[Figure: Objective of the minimax game in a WGAN (Wasserstein distance).]
It is evident that the JS divergence begins to increase as the generator becomes better, in contrast to the Wasserstein distance, which either decreases as sample quality increases or remains constant when sample quality does not improve.
Even though WGANs have not gained much traction in image generation, they have become the norm in other settings, such as zero-shot learning [6].
References
- Chochlakis, G. (2020). The math behind Variational Autoencoders (VAEs), https://thinking-ai-aloud.blogspot.com/2020/10/variational-autoencoders-vae-in-depth.html
- Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2014). Generative adversarial networks. arXiv preprint arXiv:1406.2661.
- Arjovsky, M., Chintala, S., & Bottou, L. (2017, July). Wasserstein generative adversarial networks. In International conference on machine learning (pp. 214-223). PMLR.
- Villani, C. (2008). Optimal transport: old and new (Vol. 338). Springer Science & Business Media.
- Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., & Courville, A. (2017). Improved training of Wasserstein GANs. arXiv preprint arXiv:1704.00028.
- Chochlakis, G., Georgiou, E., & Potamianos, A. (2021). End-to-end Generative Zero-shot Learning via Few-shot Learning. arXiv preprint arXiv:2102.04379.