The Math Behind GANs
Generative Adversarial Networks refer to a family of generative models that seek to discover the underlying distribution behind a certain data generating process. This distribution is discovered through an adversarial competition between a generator and a discriminator. As we saw in an earlier introductory post on GANs, the two models are trained such that the discriminator strives to distinguish between generated and true examples, while the generator seeks to confuse the discriminator by producing data that are as realistic and compelling as possible.
In this post, we’ll take a deep dive into the math behind GANs. My primary source of reference is Generative Adversarial Nets by Ian Goodfellow, et al. It is in this paper that Goodfellow first outlined the concept of a GAN, so it only makes sense that we commence our analysis there. Let’s begin!
Motivating the Loss Function
A GAN can be seen as an interplay between two different models: the generator and the discriminator. Therefore, each model will have its own loss function. In this section, let’s try to motivate an intuitive understanding of the loss function for each.
Notation
To minimize confusion, let’s define some notation that we will be using throughout this post.
\[\begin{multline} \shoveleft x: \text{Real data} \\ \shoveleft z: \text{Latent vector} \\ \shoveleft G(z): \text{Fake data} \\ \shoveleft D(x): \text{Discriminator's evaluation of real data} \\ \shoveleft D(G(z)): \text{Discriminator's evaluation of fake data} \\ \shoveleft \text{Error}(a, b): \text{Error between } a \text{ and } b\\ \end{multline}\]

The Discriminator
The goal of the discriminator is to correctly label generated images as false and empirical data points as true. Therefore, we might consider the following to be the loss function of the discriminator:
\[L_D = \text{Error}(D(x), 1) + \text{Error}(D(G(z)), 0) \tag{1}\]

Here, we are using a very generic, unspecific notation for $\text{Error}$ to refer to some function that tells us the distance or the difference between the two functional parameters. (If this reminded you of something like cross entropy or Kullback-Leibler divergence, you are definitely on the right track.)
The Generator
We can go ahead and do the same for the generator. The goal of the generator is to confuse the discriminator as much as possible such that it mislabels generated images as being true.
\[L_G = \text{Error}(D(G(z)), 1) \tag{2}\]

The key here is to remember that a loss function is something that we wish to minimize. In the case of the generator, it should strive to minimize the difference between 1, the label for true data, and the discriminator’s evaluation of the generated fake data.
Binary Cross Entropy
A common loss function that is used in binary classification problems is binary cross entropy. As a quick review, let’s remind ourselves of what the formula for cross entropy looks like:
\[H(p, q) = \mathbb{E}_{x \sim p(x)}[- \log q(x)] \tag{3}\]

In classification tasks, the random variable is discrete. Hence, the expectation can be expressed as a summation.
\[H(p, q) = - \sum_{x \in \chi} p(x) \log q(x) \tag{4}\]

We can simplify this expression even further in the case of binary cross entropy, since there are only two labels: zero and one.
\[H(y, \hat{y}) = - \sum \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right] \tag{5}\]

This is the $\text{Error}$ function that we have been loosely using in the sections above. Binary cross entropy fulfills our objective in that it measures how different two distributions are in the context of binary classification, where the task is to determine whether an input data point is true or false.
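To make (5) concrete, here is a minimal NumPy sketch of binary cross entropy. The function name and the clipping constant `eps` are my own choices for illustration, not part of any particular library; the loss is averaged over samples rather than summed.

```python
import numpy as np

def binary_cross_entropy(y, y_hat, eps=1e-12):
    """Binary cross entropy as in (5), averaged over samples.
    y holds true labels (0 or 1); y_hat holds predicted probabilities."""
    y_hat = np.clip(y_hat, eps, 1 - eps)  # guard against log(0)
    return -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

# A confident, correct prediction incurs a small loss...
print(binary_cross_entropy(np.array([1.0]), np.array([0.9])))  # ~0.105
# ...while a confident, wrong prediction is penalized heavily.
print(binary_cross_entropy(np.array([1.0]), np.array([0.1])))  # ~2.303
```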
Applying this to the loss function in (1), we get

\[L_D = - \sum_{x \in \chi, z \in \zeta} \left[ \log(D(x)) + \log(1 - D(G(z))) \right] \tag{6}\]

We can do the same for (2):
\[L_G = - \sum_{z \in \zeta} \log(D(G(z))) \tag{7}\]

Now we have two loss functions with which to train the generator and the discriminator! Note that, for the loss function of the generator, the loss is small if $D(G(z))$ is close to 1, since $\log(1) = 0$. This is exactly the sort of behavior we want from a loss function for the generator. It isn’t difficult to see the cogency of (6) with a similar approach.
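To see how (6) and (7) translate into code, here is one plausible rendering in PyTorch, not the only possible one. `D` and `G` stand in for arbitrary discriminator and generator modules, where `D` outputs probabilities and `G` outputs samples.

```python
import torch
import torch.nn.functional as F

def discriminator_loss(D, G, x, z):
    """Eq. (6): push D(x) toward 1 on real data, D(G(z)) toward 0 on fakes."""
    d_real = D(x)
    d_fake = D(G(z).detach())  # detach so no gradients flow into G here
    real_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
    fake_loss = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
    return real_loss + fake_loss

def generator_loss(D, G, z):
    """Eq. (7): push D(G(z)) toward 1, i.e. fool the discriminator."""
    d_fake = D(G(z))
    return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
```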
Minor Caveats
The original paper by Goodfellow presents a slightly different version of the two loss functions derived above.
\[\max_D \{ \log(D(x)) + \log(1-D(G(z))) \} \tag{8}\]

Essentially, the difference between (6) and (8) is the difference in sign, and whether we want to minimize or maximize a given quantity. In (6), we framed the function as a loss function to be minimized, whereas the original formulation presents it as a maximization problem, with the sign obviously flipped.
Then, Goodfellow proceeds by framing (8) as a min-max game, where the discriminator seeks to maximize the given quantity whereas the generator seeks to achieve the reverse. In other words,
\[\min_G \max_D \{ \log(D(x)) + \log(1-D(G(z))) \} \tag{9}\]

The min-max formulation is a concise one-liner that intuitively demonstrates the adversarial nature of the competition between the generator and the discriminator. However, in practice, we define separate loss functions for the generator and the discriminator as we have done above. This is because the gradient of the function $y = \log x$ is steeper near $x = 0$ than that of the function $y = \log (1 - x)$: trying to maximize $\log(D(G(z)))$, or equivalently, minimizing $- \log(D(G(z)))$, leads to quicker, more substantial improvements to the performance of the generator than trying to minimize $\log(1 - D(G(z)))$.
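A quick numerical check makes this tangible. Early in training, $D(G(z))$ sits close to 0 because the discriminator easily rejects the fakes; the NumPy snippet below shows that this is exactly the regime where $- \log x$ supplies large gradients while $\log(1 - x)$ barely moves.

```python
import numpy as np

x = np.array([0.01, 0.1, 0.5, 0.9])  # D(G(z)); small values = early training

# |d/dx log(1 - x)| = 1 / (1 - x): the "saturating" objective
grad_saturating = 1.0 / (1.0 - x)
# |d/dx (-log x)| = 1 / x: the objective used in practice
grad_non_saturating = 1.0 / x

print(grad_saturating)      # [ 1.01  1.11  2.   10.  ]
print(grad_non_saturating)  # [100.   10.    2.    1.11]
```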
Model Optimization
Now that we have defined the loss functions for the generator and the discriminator, it’s time to leverage some math to solve the optimization problem, i.e. finding the parameters for the generator and the discriminator such that the loss functions are optimized. This corresponds to training the model in practical terms.
Training the Discriminator
When training a GAN, we typically train one model at a time. In other words, when training the discriminator, the generator is assumed to be fixed, and vice versa. We saw this in action in the previous post on how to build a basic GAN.
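As a rough sketch of what this alternation looks like in code, assuming the PyTorch losses from earlier and placeholder names for the data loader, latent dimension, and optimizers:

```python
# One pass over the data, alternating between the two players.
for x in data_loader:
    # Discriminator step: G is effectively frozen (see .detach() above).
    z = torch.randn(x.size(0), latent_dim)
    d_optimizer.zero_grad()
    discriminator_loss(D, G, x, z).backward()
    d_optimizer.step()

    # Generator step: gradients flow through D, but only G's weights update.
    z = torch.randn(x.size(0), latent_dim)
    g_optimizer.zero_grad()
    generator_loss(D, G, z).backward()
    g_optimizer.step()
```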
Let’s return to the min-max game. The quantity of interest can be defined as a function of $G$ and $D$. Let’s call this the value function:
\[V(G, D) = \mathbb{E}_{x \sim p_{data}}[\log(D(x))] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))] \tag{10}\]

In reality, we are more interested in the distribution modeled by the generator than in $p_z$. Therefore, let’s create a new variable, $y = G(z)$, and use this substitution to rewrite the value function:
\[\begin{align} V(G, D) &= \mathbb{E}_{x \sim p_{data}}[\log(D(x))] + \mathbb{E}_{y \sim p_g}[\log(1 - D(y))] \\ &= \int_{x \in \chi} p_{data}(x) \log(D(x)) + p_g(x) \log(1 - D(x)) \, dx \end{align} \tag{11}\]

The goal of the discriminator is to maximize this value function. Since the integrand is concave in $D(x)$, we can maximize it pointwise for each $x$: taking the partial derivative of the integrand with respect to $D(x)$ and setting it to zero, we see that the optimal discriminator, denoted as $D^*(x)$, occurs when
\[\frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1 - D(x)} = 0 \tag{12}\]

Rearranging (12), we get
\[D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} \tag{13}\]

And this is the condition for the optimal discriminator! Note that the formula makes intuitive sense: if some sample $x$ is highly genuine, we would expect $p_{data}(x)$ to be close to one and $p_g(x)$ to be close to zero, in which case the optimal discriminator would assign a value close to 1 to that sample. On the other hand, for a generated sample $x = G(z)$, we expect the optimal discriminator to assign a label of zero, since $p_{data}(G(z))$ should be close to zero.
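As a quick sanity check with made-up densities: if at some point $x$ we have

\[p_{data}(x) = 0.9, \quad p_g(x) = 0.1, \qquad \text{then} \qquad D^*(x) = \frac{0.9}{0.9 + 0.1} = 0.9\]

In particular, if the generator were perfect, i.e. $p_g = p_{data}$, the optimal discriminator would output $D^*(x) = \frac12$ everywhere: it could do no better than a coin flip.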
Training the Generator
To train the generator, we assume the discriminator to be fixed and proceed with the analysis of the value function. Let’s first plug the result we found above, namely (13), into the value function to see what comes out.
\[\begin{align} V(G, D^*) &= \mathbb{E}_{x \sim p_{data}}[\log(D^*(x))] + \mathbb{E}_{x \sim p_g}[\log(1 - D^*(x))] \\ &= \mathbb{E}_{x \sim p_{data}} \left[ \log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} \right] + \mathbb{E}_{x \sim p_g} \left[ \log \frac{p_g(x)}{p_{data}(x) + p_g(x)} \right] \end{align} \tag{14}\]

To proceed from here, we need a little bit of inspiration. Little clever tricks like these are always a joy to look at.
\[\begin{align} V(G, D^*) &= \mathbb{E}_{x \sim p_{data}} \left[ \log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} \right] + \mathbb{E}_{x \sim p_g} \left[ \log \frac{p_g(x)}{p_{data}(x) + p_g(x)} \right] \\ &= - \log 4 + \mathbb{E}_{x \sim p_{data}} \left[ \log p_{data}(x) - \log \frac{p_{data}(x) + p_g(x)}{2} \right] \\ & \quad+ \mathbb{E}_{x \sim p_g} \left[ \log p_g(x) - \log\frac{p_{data}(x) + p_g(x)}{2} \right] \end{align} \tag{15}\]

If you are confused, don’t worry, you aren’t the only one. Basically, what is happening is that we are exploiting the properties of logarithms to pull out a $- \log 4$ that previously did not exist. In pulling out this number, we inevitably apply changes to the terms in the expectation, specifically by dividing the denominator by two.
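To spell the trick out, the identity at work inside the first expectation (the $p_g$ term is handled identically) is

\[\log \frac{p_{data}(x)}{p_{data}(x) + p_g(x)} = \log \frac{p_{data}(x)}{\frac{p_{data}(x) + p_g(x)}{2}} - \log 2\]

and the two $- \log 2$ terms, one from each expectation, add up to the $- \log 4$ out front.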
Why was this necessary? The magic here is that we can now interpret the expectations as Kullback-Leibler divergences:
\[V(G, D^*) = - \log 4 + D_{KL}\left(p_{data} \parallel \frac{p_{data} + p_g}{2} \right) + D_{KL}\left(p_g \parallel \frac{p_{data} + p_g}{2} \right) \tag{16}\]

And it is here that we reencounter the Jensen-Shannon divergence, which is defined as
\[J(P,Q) = \frac{1}{2} \left( D_{KL}(P \parallel R) + D_{KL}(Q \parallel R) \right) \tag{17}\]

where $R = \frac12(P + Q)$. This means that the expression in (16) can be expressed as a JS divergence:
\[V(G, D^*) = - \log 4 + 2 \cdot D_{JS}(p_{data} \parallel p_g) \tag{18}\]

The conclusion of this analysis is simple: to minimize the value function $V(G, D^*)$, the generator must make the JS divergence between the distribution of the data and the distribution of generated examples as small as possible. Since the JS divergence is non-negative and equals zero only when the two distributions coincide, the global minimum of the value function is $- \log 4$, attained exactly when $p_g = p_{data}$. This conclusion certainly aligns with our intuition: we want the generator to be able to learn the underlying distribution of the data from sampled training examples. In other words, $p_g$ and $p_{data}$ should be as close to each other as possible. The optimal generator $G$ is thus one that is able to mimic $p_{data}$ with a compelling model distribution $p_g$.
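If you want to convince yourself of (18) numerically, here is a small NumPy check on a pair of made-up discrete distributions; the helper `kl` and the distributions themselves are purely illustrative.

```python
import numpy as np

def kl(p, q):
    """Kullback-Leibler divergence between discrete distributions."""
    return np.sum(p * np.log(p / q))

p_data = np.array([0.1, 0.4, 0.5])  # made-up distributions over 3 outcomes
p_g = np.array([0.3, 0.3, 0.4])
m = (p_data + p_g) / 2

js = 0.5 * (kl(p_data, m) + kl(p_g, m))  # JS divergence, eq. (17)

# Evaluate V(G, D*) directly from (14), using D* from (13).
d_star = p_data / (p_data + p_g)
v = np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(1 - d_star))

print(np.isclose(v, -np.log(4) + 2 * js))  # True
```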
Conclusion
In this post, we took a brief tour of the math behind generative adversarial networks. Since the publication of Goodfellow’s work, more GAN models have been introduced and studied by different scholars, such as the Wasserstein GAN or CycleGAN, to name just a few. The underlying mathematics for these models is obviously going to be different from what we have seen today, but this is a good starting point nonetheless.
I hope you enjoyed reading this post. In the next post, I plan to explore the concept of Fisher information and the Fisher matrix. It is going to be another math-heavy ride with gradients and Hessians, so keep your seat belts fastened!