<p>LoRA, by Jake Tae, 2023-08-20 (https://jaketae.github.io/study/lora)</p>
<blockquote>
<p>I recently completed another summer internship at Meta (formerly Facebook). I was surprised to learn that one of the intern friends I met was an avid reader of my blog. Encouraged by the positive feedback from my intern friends, I decided to write another post before the end of summer. This post is dedicated to the mandem: Yassir, Amal, Ryan, Elvis, and Sam.</p>
</blockquote>
<p>Today, we will take a look at <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of Large Language Models</a> by Hu et al. Alongside <a href="https://github.com/TimDettmers/bitsandbytes">bitsandbytes</a>, LoRA has been a key ingredient in democratizing language models like <a href="https://ai.meta.com/llama/">Llama</a><sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>, making them available for both inference and fine-tuning on consumer-grade GPUs. In particular, LoRA is a prerequisite for understanding <a href="https://arxiv.org/abs/2305.14314">QLoRA</a>, which combines 4-bit quantization with low-rank adaptation.</p>
<p>This post was heavily inspired by other great resources on LoRA:</p>
<ol>
<li>sunildkumar’s <a href="https://github.com/sunildkumar/lora_from_scratch/tree/main">lora_from_scratch</a></li>
<li>Hugging Face’s <a href="https://huggingface.co/docs/peft/conceptual_guides/lora">LoRA Concept Guide</a></li>
<li>Chris Alexiuk’s <a href="https://www.youtube.com/watch?v=dA-NhCtrrVE&t=767s">YouTube Video</a></li>
</ol>
<p>Let’s get into it!</p>
<h1 id="recap">Recap</h1>
<p>In this section, we will quickly recap some concepts in linear algebra, which will help us understand LoRA.</p>
<h2 id="rank">Rank</h2>
<p>In linear algebra, the <a href="https://en.wikipedia.org/wiki/Rank_(linear_algebra)">rank</a> of a matrix is the dimension of its column space, which always equals the dimension of its row space. In other words, it is the maximum number of linearly independent columns (or, equivalently, rows) of the matrix. Another handy fact about rank is that only full-rank square matrices are invertible. To recall:</p>
<ol>
<li>Row rank equals column rank;</li>
<li>A square matrix is invertible if and only if it is full-rank;</li>
<li>A square matrix is full-rank if and only if its determinant is non-zero.</li>
</ol>
<p>Without getting into too much detail, a rough proof sketch for these propositions uses elementary row operations to reduce the matrix to (reduced) row echelon form, combined with elementary facts about invertibility and determinants.</p>
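<p>As a quick numerical sanity check of these facts, here is a small sketch using PyTorch’s <code class="language-plaintext highlighter-rouge">torch.linalg</code> module. Both matrices are made up for illustration: the second one is deliberately rank-deficient, since its third row is the sum of the first two.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# A full-rank 3 x 3 matrix (illustrative values).
full = torch.tensor([[2.0, 0.0, 1.0],
                     [1.0, 3.0, 0.0],
                     [0.0, 1.0, 4.0]])

# A rank-deficient 3 x 3 matrix: row 3 = row 1 + row 2.
deficient = torch.tensor([[1.0, 2.0, 3.0],
                          [4.0, 5.0, 6.0],
                          [5.0, 7.0, 9.0]])

for name, M in [("full", full), ("deficient", deficient)]:
    col_rank = torch.linalg.matrix_rank(M).item()
    row_rank = torch.linalg.matrix_rank(M.T).item()  # rank of the transpose is the row rank
    det = torch.linalg.det(M).item()
    print(f"{name}: column rank={col_rank}, row rank={row_rank}, det={det:.4f}")

# Only the full-rank matrix has a well-defined inverse.
inv = torch.linalg.inv(full)
print(torch.allclose(full @ inv, torch.eye(3), atol=1e-5))
</code></pre></div></div>
<p>Under these assumptions, the first matrix should report rank 3 for both rows and columns and a determinant of 25, while the second should report rank 2 and a determinant that is zero up to floating-point error.</p>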
<h2 id="decomposition">Decomposition</h2>
<p>For the purposes of understanding LoRA, it suffices to think of rank as a measure of how much information a matrix encodes, viewed through the lens of decomposition. Concretely, consider the $4 \times 4$ matrix $A$:</p>
\[A = \begin{pmatrix}
1 & 3 & 1 & 4 \\
2 & 7 & 3 & 9 \\
1 & 5 & 3 & 1 \\
1 & 2 & 0 & 8 \\
\end{pmatrix}\]
<p>Its reduced row echelon form is $B$:</p>
\[B = \begin{pmatrix}
1 & 0 & -2 & 0 \\
0 & 1 & 1 & 0 \\
0 & 0 & 0 & 1 \\
0 & 0 & 0 & 0 \\
\end{pmatrix}\]
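<p>Since $B$ has three non-zero rows (i.e., three pivots), $A$ is a rank 3 matrix. The claim is that $A$ can then be decomposed into two matrices of size (4, 3) and (3, 4): the first factor consists of the columns of $A$ at the pivot positions of $B$ (columns 1, 2, and 4), and the second consists of the three non-zero rows of $B$. Indeed, we have</p>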
\[\begin{pmatrix}
1 & 3 & 4 \\
2 & 7 & 9 \\
1 & 5 & 1 \\
1 & 2 & 8 \\
\end{pmatrix}
\begin{pmatrix}
1 & 0 & -2 & 0 \\
0 & 1 & 1 & 0 \\
0 & 0 & 0 & 1 \\
\end{pmatrix}
=
\begin{pmatrix}
1 & 3 & 1 & 4 \\
2 & 7 & 3 & 9 \\
1 & 5 & 3 & 1 \\
1 & 2 & 0 & 8 \\
\end{pmatrix}\]
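<p>We can double-check this factorization numerically. The sketch below, assuming PyTorch (which we use later in this post anyway), rebuilds both factors from $A$ and its reduced row echelon form: the left factor keeps the columns of $A$ at the pivot positions (0-indexed columns 0, 1, and 3), and the right factor keeps the non-zero rows of the echelon form.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# The 4 x 4 matrix A from the text.
A = torch.tensor([[1., 3., 1., 4.],
                  [2., 7., 3., 9.],
                  [1., 5., 3., 1.],
                  [1., 2., 0., 8.]])

# Its reduced row echelon form (called B in the text).
rref = torch.tensor([[1., 0., -2., 0.],
                     [0., 1., 1., 0.],
                     [0., 0., 0., 1.],
                     [0., 0., 0., 0.]])

print(torch.linalg.matrix_rank(A).item())  # 3

pivot_cols = [0, 1, 3]   # 0-indexed columns of the RREF that hold a leading 1
C = A[:, pivot_cols]     # (4, 3) left factor: pivot columns of A
R = rref[:3]             # (3, 4) right factor: non-zero rows of the RREF

print(torch.equal(C @ R, A))  # True
</code></pre></div></div>
<p>Because every entry here is a small integer, the product is reproduced exactly, so an exact equality check is appropriate.</p>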
<p>In other words, rank determines the shape of the factors in a matrix decomposition. In this example, $A$ was relatively close to being full-rank: its rank was 3, while the maximal rank it could have had was 4. More generally, any $n \times n$ matrix of rank $m$ can be written as the product of an $(n, m)$ matrix and an $(m, n)$ matrix. When $m \ll n$, the two factors hold only $2nm$ entries in total instead of $n^2$. This is precisely the key idea behind LoRA.</p>
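<p>The same idea can be sketched at a larger scale. Assuming randomly initialized factors (an illustrative choice, not anything specific to LoRA), multiplying an $(n, m)$ matrix by an $(m, n)$ matrix yields an $n \times n$ matrix of rank $m$ (with probability one for Gaussian factors), while the factors store far fewer numbers than the full matrix.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
n, m = 512, 8

U = torch.randn(n, m)   # (n, m) factor, random for illustration
V = torch.randn(m, n)   # (m, n) factor, random for illustration
M = U @ V               # (n, n) product, so its rank is at most m

print(torch.linalg.matrix_rank(M).item())   # m for generic random factors
print(n * n, U.numel() + V.numel())          # 262144 8192
</code></pre></div></div>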
<h1 id="lora">LoRA</h1>
<p>Large language models (LLMs) have grown very large in recent years. Even the smaller ones, such as Llama 7B, have billions of parameters. Finetuning such large models for individual tasks is, to say the least, prohibitively expensive.</p>
<p>LoRA is a parameter-efficient training method. In short, instead of updating all of the model’s parameters, LoRA freezes the pretrained weights and trains a small number of extra parameters whose output is added to that of the original layers. Let’s see what this means exactly.</p>
<h2 id="low-rank">Low Rank</h2>
<p>LoRA starts from the simple hypothesis that</p>
<blockquote>
<p>the change in weights during model adaptation … has a low “intrinsic rank”[.]</p>
</blockquote>
<p>In other words, the authors of LoRA hypothesize that the update to the model weights during finetuning, i.e., the difference between the finetuned and pretrained weights, is a low-rank matrix. If this is true, we should be able to emulate the effects of full finetuning by training only two small matrices whose product is low-rank. This is precisely what LoRA does.</p>
<blockquote>
<p>LoRA allows us to train some dense layers in a neural network indirectly by optimizing rank decomposition matrices of the dense layers’ change during adaptation instead, while keeping the pre-trained weights frozen[.]</p>
</blockquote>
<figure>
<img src="https://global-uploads.webflow.com/63f3993d10c2a062a4c9f13c/64649977d084d2b4b66c6492_1*e5pYWjrZR3eA_YbCKu8deQ.png" /> <figcaption>
Overview of LoRA. Image from ML6.
</figcaption>
</figure>
<p>In the diagram above, $W_\text{nk}$ is the full pretrained weight matrix. Instead of finetuning $W_\text{nk}$ in its entirety, LoRA adds two auxiliary matrices, $A$ and $B$, whose shared inner dimension is $r$, so that their product $AB$ has rank at most $r$. Let $W_\text{nk} \in \mathbb{R}^{n \times k}$. Then $A \in \mathbb{R}^{n \times r}$ and $B \in \mathbb{R}^{r \times k}$. If $r$ is small enough, training only $A$ and $B$ is much cheaper than training $W_\text{nk}$, i.e.,</p>
\[nk > r (n + k).\]
<p>Conversely, when $r \geq \min(n, k)$, the product $AB$ can represent any $n \times k$ matrix, so we recover the expressiveness of full finetuning. Therefore, LoRA can be seen as a generalization of full finetuning.</p>
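<p>To make the savings concrete, here is a small back-of-the-envelope sketch comparing $nk$ with $r(n + k)$. The $4096 \times 4096$ shape is an illustrative assumption (roughly the size of one attention projection in a 7B-scale model), and the ranks are arbitrary.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def full_params(n: int, k: int) -> int:
    """Trainable entries when finetuning W_nk directly."""
    return n * k


def lora_params(n: int, k: int, r: int) -> int:
    """Trainable entries in A (n x r) and B (r x k)."""
    return r * (n + k)


n, k = 4096, 4096  # illustrative shape for one projection matrix
for r in [1, 8, 64, 4096]:
    ratio = full_params(n, k) / lora_params(n, k, r)
    print(f"r={r}: LoRA params={lora_params(n, k, r)}, full/LoRA={ratio:.1f}")
</code></pre></div></div>
<p>With these numbers, $r = 8$ trains 65,536 entries instead of 16,777,216, a 256-fold reduction. At $r = 4096$, however, LoRA needs twice as many parameters as full finetuning: for a square $n \times n$ matrix the two counts are equal at $r = n / 2$, so the method only pays off for ranks well below that.</p>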
<p>The authors limit their study of LoRA to the Transformer architecture, testing it on a wide range of encoder and decoder models such as RoBERTa, DeBERTa, and GPT-3. They apply LoRA to the weight matrices of the self-attention module.</p>
<h2 id="inference">Inference</h2>
<p>The forward pass of a LoRA model can be written as</p>
\[h = W_\text{nk} X + AB X\]
<p>Compare this with the original, unmodified forward pass:</p>
\[h = W_\text{nk} X.\]
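<p>To make the two-branch forward pass concrete, below is a minimal PyTorch sketch of a LoRA-wrapped linear layer. The class name <code class="language-plaintext highlighter-rouge">LoRALinear</code> and its arguments are hypothetical, not taken from the paper or any library. Two details are assumptions worth flagging: since <code class="language-plaintext highlighter-rouge">nn.Linear</code> stores its weight as (out_features, in_features) and operates on row vectors, the update is applied as $x (AB)^\top$ rather than $ABX$; and, following the initialization described in the LoRA paper, the $(n, r)$ factor starts at zero so that $AB = 0$ and the wrapped layer initially matches the pretrained one. The paper additionally scales the update by a constant $\alpha / r$, which this sketch omits.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
from torch import nn


class LoRALinear(nn.Module):
    """Hypothetical LoRA wrapper: a frozen nn.Linear plus a trainable rank-r update."""

    def __init__(self, base: nn.Linear, r: int = 4):
        super().__init__()
        self.base = base
        # Freeze the pretrained weight (and bias) so only A and B are trained.
        for p in self.base.parameters():
            p.requires_grad = False
        n, k = base.out_features, base.in_features
        # A: (n, r), initialized to zero so that AB = 0 at the start of training.
        self.A = nn.Parameter(torch.zeros(n, r))
        # B: (r, k), small random initialization.
        self.B = nn.Parameter(torch.randn(r, k) * 0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, k). nn.Linear computes x W^T + b, so the low-rank branch
        # adds x (AB)^T = (x B^T) A^T without ever materializing AB.
        return self.base(x) + (x @ self.B.T) @ self.A.T


torch.manual_seed(0)
base = nn.Linear(32, 64)        # stands in for a pretrained layer with n = 64, k = 32
lora = LoRALinear(base, r=4)
x = torch.randn(8, 32)

print(torch.allclose(lora(x), base(x)))  # True, because A starts at zero
trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(trainable)  # 4 * (64 + 32) = 384
</code></pre></div></div>
<p>Computing $(x B^\top) A^\top$ instead of forming $AB$ first keeps the extra cost of the low-rank branch proportional to $r(n + k)$ per input rather than $nk$.</p>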
<p>We could keep separate modules for the original weight matrix $W_\text{nk}$ and for $A$ and $B$. However, after training is complete, we can speed up the forward pass by fusing them into a single weight matrix, which eliminates the extra FLOPs of the low-rank branch.</p>
\[\begin{align*}
W'_\text{nk} &= W_\text{nk} + AB \\
h &= W'_\text{nk} X.
\end{align*}\]
<p>In other words, instead of maintaining the two-branch structure, we simply merge the LoRA delta matrix into the original frozen pretrained weight.</p>
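<p>The merge is exact because matrix multiplication distributes over addition: $W_\text{nk} X + ABX = (W_\text{nk} + AB) X$. The following sketch checks this numerically with random tensors; the shapes are arbitrary, and the inputs are stored as columns to match the equations above.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
n, k, r, batch = 64, 32, 4, 8

W = torch.randn(n, k)         # frozen pretrained weight W_nk
A = torch.randn(n, r) * 0.1   # trained LoRA factor A (n x r)
B = torch.randn(r, k) * 0.1   # trained LoRA factor B (r x k)
X = torch.randn(k, batch)     # inputs stored as columns, as in h = W X

# Two-branch forward pass: h = W X + A (B X)
h_branch = W @ X + A @ (B @ X)

# Fused forward pass: W' = W + A B, then h = W' X
W_merged = W + A @ B
h_merged = W_merged @ X

print(torch.allclose(h_branch, h_merged, atol=1e-4))  # True
</code></pre></div></div>
<p>After merging, the layer has exactly the same shape and cost as the original $W_\text{nk}$, which is why LoRA adds no inference latency once training is done.</p>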
<h1 id="implementation">Implementation</h1>
<p>Now that we have an idea of how LoRA works, let’s try a simple implementation with <a href="https://lightning.ai/pytorch-lightning">PyTorch Lightning</a>. This implementation was heavily inspired by sunildkumar’s <a href="https://github.com/sunildkumar/lora_from_scratch/tree/main">lora_from_scratch</a>.</p>
<h2 id="setup">Setup</h2>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">%%</span><span class="n">capture</span>
<span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="n">lightning</span>
</code></pre></div></div>
<p>We import the necessary dependencies and set the seed for reproducibility.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">optim</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
<span class="kn">from</span> <span class="nn">torch.utils.data</span> <span class="kn">import</span> <span class="n">random_split</span><span class="p">,</span> <span class="n">DataLoader</span>
<span class="kn">from</span> <span class="nn">torchvision.datasets</span> <span class="kn">import</span> <span class="n">MNIST</span>
<span class="kn">from</span> <span class="nn">torchvision</span> <span class="kn">import</span> <span class="n">transforms</span>
<span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Accuracy</span>
<span class="kn">import</span> <span class="nn">lightning.pytorch</span> <span class="k">as</span> <span class="n">pl</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">seaborn</span> <span class="k">as</span> <span class="n">sns</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="o">%</span><span class="n">matplotlib</span> <span class="n">inline</span>
<span class="o">%</span><span class="n">config</span> <span class="n">InlineBackend</span><span class="p">.</span><span class="n">figure_format</span><span class="o">=</span><span class="s">'retina'</span>
<span class="n">pl</span><span class="p">.</span><span class="n">seed_everything</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: Global seed set to 42
</code></pre></div></div>
<p>We will be using the MNIST toy dataset. PyTorch Lightning provides a convenient data module API, where we can pack all logic related to the data into a single class.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MNISTDataModule</span><span class="p">(</span><span class="n">pl</span><span class="p">.</span><span class="n">LightningDataModule</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_dir</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s">"."</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1024</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">data_dir</span> <span class="o">=</span> <span class="n">data_dir</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">transform</span> <span class="o">=</span> <span class="n">transforms</span><span class="p">.</span><span class="n">ToTensor</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">stage</span><span class="p">:</span> <span class="nb">str</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">stage</span> <span class="o">==</span> <span class="s">"fit"</span><span class="p">:</span>
            <span class="n">mnist_full</span> <span class="o">=</span> <span class="n">MNIST</span><span class="p">(</span>
                <span class="bp">self</span><span class="p">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">transform</span><span class="p">,</span>
            <span class="p">)</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">mnist_train</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">mnist_val</span> <span class="o">=</span> <span class="n">random_split</span><span class="p">(</span><span class="n">mnist_full</span><span class="p">,</span> <span class="p">[</span><span class="mi">55000</span><span class="p">,</span> <span class="mi">5000</span><span class="p">])</span>
        <span class="k">elif</span> <span class="n">stage</span> <span class="o">==</span> <span class="s">"test"</span><span class="p">:</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">mnist_test</span> <span class="o">=</span> <span class="n">MNIST</span><span class="p">(</span>
                <span class="bp">self</span><span class="p">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">transform</span><span class="p">,</span>
            <span class="p">)</span>
        <span class="k">elif</span> <span class="n">stage</span> <span class="o">==</span> <span class="s">"predict"</span><span class="p">:</span>
            <span class="c1"># define mnist_predict so that predict_dataloader below has data to load</span>
            <span class="bp">self</span><span class="p">.</span><span class="n">mnist_predict</span> <span class="o">=</span> <span class="n">MNIST</span><span class="p">(</span>
                <span class="bp">self</span><span class="p">.</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">transform</span><span class="p">,</span>
            <span class="p">)</span>

    <span class="k">def</span> <span class="nf">train_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mnist_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">batch_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">val_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mnist_val</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">batch_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">test_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mnist_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">batch_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">predict_dataloader</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">mnist_predict</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">batch_size</span><span class="p">)</span>
</code></pre></div></div>
<p>In this toy experiment, we will first train the model on MNIST for 5 epochs. The baseline will then continue training on the same dataset for 5 more epochs to emulate the effects of full “finetuning.” The LoRA model will be initialized from the 5-epoch checkpoint and then trained for another 5 epochs.</p>
<h2 id="pretraining">Pretraining</h2>
<p>We will train a simple fully connected model with three linear layers.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MNISTModel</span><span class="p">(</span><span class="n">pl</span><span class="p">.</span><span class="n">LightningModule</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">2e-4</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">lr</span> <span class="o">=</span> <span class="n">lr</span>
        <span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l1</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">28</span> <span class="o">*</span> <span class="mi">28</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l2</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l3</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">val_accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">(</span>
            <span class="n">task</span><span class="o">=</span><span class="s">"multiclass"</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">test_accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">(</span>
            <span class="n">task</span><span class="o">=</span><span class="s">"multiclass"</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">num_classes</span><span class="p">,</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">start_dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># F.dropout defaults to training=True; pass the module's mode so dropout is off at eval time</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">l1</span><span class="p">(</span><span class="n">x</span><span class="p">)),</span> <span class="n">training</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">training</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">l2</span><span class="p">(</span><span class="n">x</span><span class="p">)),</span> <span class="n">training</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">training</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">l3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>

    <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">lr</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">base_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">batch</span>
        <span class="n">logits</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span>

    <span class="k">def</span> <span class="nf">training_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_step</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"train_loss"</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">loss</span>

    <span class="k">def</span> <span class="nf">validation_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_step</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
        <span class="n">preds</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">val_accuracy</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"val_loss"</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"val_acc"</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">val_accuracy</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">):</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">base_step</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">batch_idx</span><span class="p">)</span>
        <span class="n">preds</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">test_accuracy</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"test_loss"</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="s">"test_acc"</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">test_accuracy</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">MNISTModel</span><span class="p">()</span>
<span class="n">datamodule</span> <span class="o">=</span> <span class="n">MNISTDataModule</span><span class="p">()</span>
<span class="n">pretrainer</span> <span class="o">=</span> <span class="n">pl</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span>
    <span class="n">accelerator</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span>
    <span class="n">devices</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">logger</span><span class="o">=</span><span class="n">pl</span><span class="p">.</span><span class="n">loggers</span><span class="p">.</span><span class="n">CSVLogger</span><span class="p">(</span><span class="s">"logs"</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">pretrainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
<span class="n">pretrainer</span><span class="p">.</span><span class="n">test</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
  | Name          | Type               | Params
-----------------------------------------------------
0 | l1            | Linear             | 50.2 K
1 | l2            | Linear             | 4.2 K
2 | l3            | Linear             | 650
3 | val_accuracy  | MulticlassAccuracy | 0
4 | test_accuracy | MulticlassAccuracy | 0
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
</code></pre></div></div>
<div style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace" class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │    0.7035999894142151     │
│         test_loss         │    0.9380418062210083     │
└───────────────────────────┴───────────────────────────┘
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[{'test_loss': 0.9380418062210083, 'test_acc': 0.7035999894142151}]
</code></pre></div></div>
<p>Let’s load the logged metrics from the CSV file and plot them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">read_metrics</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
    <span class="n">metrics</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="k">del</span> <span class="n">metrics</span><span class="p">[</span><span class="s">"step"</span><span class="p">]</span>
    <span class="n">metrics</span><span class="p">.</span><span class="n">set_index</span><span class="p">(</span><span class="s">"epoch"</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
    <span class="n">display</span><span class="p">(</span><span class="n">metrics</span><span class="p">.</span><span class="n">dropna</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">how</span><span class="o">=</span><span class="s">"all"</span><span class="p">).</span><span class="n">head</span><span class="p">())</span>
    <span class="n">sns</span><span class="p">.</span><span class="n">relplot</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">metrics</span><span class="p">,</span> <span class="n">kind</span><span class="o">=</span><span class="s">"line"</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>

<span class="n">read_metrics</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">pretrainer</span><span class="p">.</span><span class="n">logger</span><span class="p">.</span><span class="n">log_dir</span><span class="si">}</span><span class="s">/metrics.csv"</span><span class="p">)</span>
</code></pre></div></div>
<div id="df-58544aec-ab62-4863-afca-1870744cc541" class="colab-df-container">
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>train_loss</th>
<th>val_loss</th>
<th>val_acc</th>
<th>test_loss</th>
<th>test_acc</th>
</tr>
<tr>
<th>epoch</th>
<th></th>
<th></th>
<th></th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>2.113543</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>0</th>
<td>NaN</td>
<td>2.083865</td>
<td>0.3740</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>1</th>
<td>1.708588</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>1</th>
<td>NaN</td>
<td>1.654867</td>
<td>0.5168</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>2</th>
<td>1.350459</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
<td>NaN</td>
</tr>
</tbody>
</table>
</div>
</div>
<p><img src="/assets/images/2023-08-20-lora_files/2023-08-20-lora_14_1.png" /></p>
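<p>The table above lists each epoch on two rows: one carrying the training loss and one carrying the validation metrics, with <code class="language-plaintext highlighter-rouge">NaN</code> in the remaining columns. This looks like the CSV logger writing training and validation metrics as separate rows for the same epoch. Assuming that is the cause, here is a minimal pandas sketch that merges them into one row per epoch, using the values shown in the table:</p>

```python
import numpy as np
import pandas as pd

# Rows copied from the table above: training and validation metrics
# for the same epoch sit on separate rows, with NaN elsewhere.
metrics = pd.DataFrame(
    {
        "epoch": [0, 0, 1, 1, 2],
        "train_loss": [2.113543, np.nan, 1.708588, np.nan, 1.350459],
        "val_loss": [np.nan, 2.083865, np.nan, 1.654867, np.nan],
        "val_acc": [np.nan, 0.3740, np.nan, 0.5168, np.nan],
    }
).set_index("epoch")

# groupby(...).first() keeps the first non-null value of each column
# within an epoch, which merges the train row and the validation row.
per_epoch = metrics.groupby(level="epoch").first()
print(per_epoch)
```

<p>Epoch 2 still shows <code class="language-plaintext highlighter-rouge">NaN</code> validation values because the table only includes the first five logged rows.</p>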
<p>We save the model twice: once with the PyTorch Lightning trainer API and once with the plain PyTorch API. We will use the Lightning checkpoint to continue training the model, which simulates full finetuning, and the PyTorch <code class="language-plaintext highlighter-rouge">state_dict</code> to initialize the LoRA model from the trained weights.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">pretrainer</span><span class="p">.</span><span class="n">save_checkpoint</span><span class="p">(</span><span class="s">"model.ckpt"</span><span class="p">)</span>
<span class="n">torch</span><span class="p">.</span><span class="n">save</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s">'model.pt'</span><span class="p">)</span>
</code></pre></div></div>
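<p>The plain <code class="language-plaintext highlighter-rouge">state_dict</code> holds only the pretrained layers’ weights. Loading it into a model that also has LoRA parameters therefore needs <code class="language-plaintext highlighter-rouge">strict=False</code>. With that flag, PyTorch copies the matching tensors and reports the missing ones instead of raising an error. Here is a small sketch with two hypothetical toy modules, <code class="language-plaintext highlighter-rouge">Base</code> and <code class="language-plaintext highlighter-rouge">WithAdapter</code>, which are not part of this post’s code:</p>

```python
import torch
from torch import nn


class Base(nn.Module):
    # Hypothetical stand-in for a pretrained model with one linear layer.
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 3)


class WithAdapter(Base):
    # Hypothetical subclass that adds extra (LoRA-like) parameters.
    def __init__(self):
        super().__init__()
        self.A = nn.Parameter(torch.zeros(4, 2))
        self.B = nn.Parameter(torch.zeros(2, 3))


base = Base()
adapted = WithAdapter()

# strict=False loads the matching keys (l1.weight, l1.bias) and reports
# the remaining ones instead of raising an error.
result = adapted.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)     # parameters the checkpoint did not provide
print(result.unexpected_keys)  # checkpoint entries the model does not have

# The pretrained weights were copied over unchanged.
print(torch.equal(adapted.l1.weight, base.l1.weight))
```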
<h2 id="baseline">Baseline</h2>
<p>Let’s continue training the model for 5 more epochs to see how it improves. This is the full finetuning baseline.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span> <span class="o">=</span> <span class="n">MNISTModel</span><span class="p">.</span><span class="n">load_from_checkpoint</span><span class="p">(</span><span class="s">"model.ckpt"</span><span class="p">)</span>
<span class="n">trainer</span> <span class="o">=</span> <span class="n">pl</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span>
    <span class="n">accelerator</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span>
    <span class="n">devices</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">logger</span><span class="o">=</span><span class="n">pl</span><span class="p">.</span><span class="n">loggers</span><span class="p">.</span><span class="n">CSVLogger</span><span class="p">(</span><span class="s">"logs"</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
<span class="n">read_metrics</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">trainer</span><span class="p">.</span><span class="n">logger</span><span class="p">.</span><span class="n">log_dir</span><span class="si">}</span><span class="s">/metrics.csv"</span><span class="p">)</span>
<span class="n">trainer</span><span class="p">.</span><span class="n">test</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
| Name | Type | Params
-----------------------------------------------------
0 | l1 | Linear | 50.2 K
1 | l2 | Linear | 4.2 K
2 | l3 | Linear | 650
3 | val_accuracy | MulticlassAccuracy | 0
4 | test_accuracy | MulticlassAccuracy | 0
-----------------------------------------------------
55.1 K Trainable params
0 Non-trainable params
55.1 K Total params
0.220 Total estimated model params size (MB)
INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
</code></pre></div></div>
<div id="df-22333804-133d-4304-9922-1c3d99c2dbc7" class="colab-df-container">
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>train_loss</th>
<th>val_loss</th>
<th>val_acc</th>
</tr>
<tr>
<th>epoch</th>
<th></th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>0.920625</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>0</th>
<td>NaN</td>
<td>0.848711</td>
<td>0.7360</td>
</tr>
<tr>
<th>1</th>
<td>0.804315</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>1</th>
<td>NaN</td>
<td>0.757564</td>
<td>0.7658</td>
</tr>
<tr>
<th>2</th>
<td>0.751054</td>
<td>NaN</td>
<td>NaN</td>
</tr>
</tbody>
</table>
</div>
</div>
<p><img src="/assets/images/2023-08-20-lora_files/2023-08-20-lora_18_10.png" /></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing: 0it [00:00, ?it/s]
</code></pre></div></div>
<div style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace" class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_acc │ 0.809499979019165 │
│ test_loss │ 0.6315763592720032 │
└───────────────────────────┴───────────────────────────┘
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[{'test_loss': 0.6315763592720032, 'test_acc': 0.809499979019165}]
</code></pre></div></div>
<p>Test accuracy improves from about 0.70 after pretraining to about 0.81 after five more epochs of full finetuning, as expected.</p>
<h2 id="lora-1">LoRA</h2>
<p>Next, we build the LoRA model. We start with a simple <code class="language-plaintext highlighter-rouge">LoRALinear</code> class that encapsulates the initialization of the two low-rank matrices and the forward pass through them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="k">class</span> <span class="nc">LoRALinear</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="c1"># Low-rank factors: A is (in_features, rank), B is (rank, out_features)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">empty</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">rank</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">B</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">empty</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">out_features</span><span class="p">))</span>
        <span class="c1"># A gets a Kaiming-uniform init; B starts at zero, so A @ B is zero at the start of training</span>
        <span class="n">nn</span><span class="p">.</span><span class="n">init</span><span class="p">.</span><span class="n">kaiming_uniform_</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">A</span><span class="p">,</span> <span class="n">a</span><span class="o">=</span><span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">5</span><span class="p">))</span>
        <span class="n">nn</span><span class="p">.</span><span class="n">init</span><span class="p">.</span><span class="n">zeros_</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">B</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># Multiply by the (in_features, out_features) update A @ B, whose rank is at most rank</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">@</span> <span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">A</span> <span class="o">@</span> <span class="bp">self</span><span class="p">.</span><span class="n">B</span><span class="p">)</span>
</code></pre></div></div>
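<p><code class="language-plaintext highlighter-rouge">B</code> is initialized to zeros, so the product <code class="language-plaintext highlighter-rouge">A @ B</code> is zero at the start of training. Adding the LoRA branch therefore leaves the pretrained model’s outputs unchanged at first. The layer also holds only <code class="language-plaintext highlighter-rouge">rank * (in_features + out_features)</code> parameters rather than <code class="language-plaintext highlighter-rouge">in_features * out_features</code>. Here is a quick sketch that checks both properties. It re-declares the class above so that it runs on its own, and uses the first layer’s shape (784 inputs, 64 outputs, rank 32) as an example:</p>

```python
import math

import torch
from torch import nn


class LoRALinear(nn.Module):
    # Same structure as the LoRALinear class defined above.
    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.A = nn.Parameter(torch.empty(in_features, rank))
        self.B = nn.Parameter(torch.empty(rank, out_features))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)

    def forward(self, x):
        return x @ (self.A @ self.B)


layer = LoRALinear(28 * 28, 64, rank=32)
x = torch.randn(8, 28 * 28)

# B is all zeros at initialization, so the update A @ B is the zero matrix
# and the layer outputs zeros for any input.
out = layer(x)
print(out.shape, out.abs().max().item())

# Parameter count: rank * (in + out) = 32 * (784 + 64).
n_params = sum(p.numel() for p in layer.parameters())
print(n_params)
```

<p>The count of 27,136 matches the 27.1 K reported for <code class="language-plaintext highlighter-rouge">l1_lora</code> in the model summary further down.</p>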
<p>The LoRA model subclasses <code class="language-plaintext highlighter-rouge">MNISTModel</code>. Its constructor performs two steps:</p>
<ol>
<li>Freeze the already trained parameters from <code class="language-plaintext highlighter-rouge">MNISTModel</code>;</li>
<li>Initialize the low-rank matrices via <code class="language-plaintext highlighter-rouge">LoRALinear</code> with the specified <code class="language-plaintext highlighter-rouge">rank</code>.</li>
</ol>
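<p>In the forward pass, the output of each LoRA layer is scaled by <code class="language-plaintext highlighter-rouge">alpha</code> and added to the output of the corresponding frozen layer. <code class="language-plaintext highlighter-rouge">alpha</code> therefore controls how strongly the low-rank update affects the pretrained model’s activations.</p>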
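<p>Written out for a single layer, the forward pass below computes the following before any ReLU. The code multiplies row vectors <code class="language-plaintext highlighter-rouge">x</code> on the left. Here <em>W</em> and <em>b</em> denote the frozen <code class="language-plaintext highlighter-rouge">nn.Linear</code> weight and bias, and <em>A</em>, <em>B</em> are the <code class="language-plaintext highlighter-rouge">LoRALinear</code> factors. The second form shows that LoRA is equivalent to perturbing the frozen weight by a matrix of rank at most <em>r</em>. Note that the LoRA paper scales the update by α/r, whereas this implementation uses <code class="language-plaintext highlighter-rouge">alpha</code> directly.</p>

```latex
h = x W^\top + b + \alpha \, x A B
  = x \left( W + \alpha \, (A B)^\top \right)^\top + b,
\qquad
A \in \mathbb{R}^{d_\text{in} \times r}, \;
B \in \mathbb{R}^{r \times d_\text{out}}, \;
W \in \mathbb{R}^{d_\text{out} \times d_\text{in}}
```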
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">MNISTLoRAModel</span><span class="p">(</span><span class="n">MNISTModel</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">alpha</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
        <span class="c1"># Freeze the pretrained layers (l1, l2, l3); only the LoRA layers created below will train</span>
        <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">parameter</span> <span class="ow">in</span> <span class="bp">self</span><span class="p">.</span><span class="n">named_parameters</span><span class="p">():</span>
            <span class="n">parameter</span><span class="p">.</span><span class="n">requires_grad</span> <span class="o">=</span> <span class="bp">False</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">rank</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l1_lora</span> <span class="o">=</span> <span class="n">LoRALinear</span><span class="p">(</span><span class="mi">28</span> <span class="o">*</span> <span class="mi">28</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l2_lora</span> <span class="o">=</span> <span class="n">LoRALinear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">l3_lora</span> <span class="o">=</span> <span class="n">LoRALinear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">rank</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">start_dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Each layer adds alpha times its LoRA update to the output of the frozen layer</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">l1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">l1_lora</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">F</span><span class="p">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">l2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">l2_lora</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">l3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="p">.</span><span class="n">alpha</span> <span class="o">*</span> <span class="bp">self</span><span class="p">.</span><span class="n">l3_lora</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
    <span class="k">def</span> <span class="nf">configure_optimizers</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">optimizer</span> <span class="o">=</span> <span class="nb">super</span><span class="p">().</span><span class="n">configure_optimizers</span><span class="p">()</span>
        <span class="c1"># Reduce the learning rate if val_loss stops improving for 10 epochs</span>
        <span class="n">scheduler</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">lr_scheduler</span><span class="p">.</span><span class="n">ReduceLROnPlateau</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="s">"min"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
        <span class="k">return</span> <span class="p">{</span>
            <span class="s">"optimizer"</span><span class="p">:</span> <span class="n">optimizer</span><span class="p">,</span>
            <span class="s">"lr_scheduler"</span><span class="p">:</span> <span class="p">{</span>
                <span class="s">"scheduler"</span><span class="p">:</span> <span class="n">scheduler</span><span class="p">,</span>
                <span class="s">"monitor"</span><span class="p">:</span> <span class="s">"val_loss"</span><span class="p">,</span>
                <span class="s">"frequency"</span><span class="p">:</span> <span class="mi">1</span>
            <span class="p">},</span>
        <span class="p">}</span>
</code></pre></div></div>
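<p>Because the LoRA branch is simply added to the frozen layer’s output, the scaled low-rank update can be folded into the frozen weight once training is done. Inference then needs a single matrix multiply per layer. The merge step is not part of this post’s code. The minimal sketch below assumes a standalone <code class="language-plaintext highlighter-rouge">nn.Linear</code> and random factors <code class="language-plaintext highlighter-rouge">A</code> and <code class="language-plaintext highlighter-rouge">B</code> shaped like those in <code class="language-plaintext highlighter-rouge">LoRALinear</code>. It checks that the merged layer reproduces the unmerged forward pass:</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
in_features, out_features, rank, alpha = 64, 10, 4, 1.0

base = nn.Linear(in_features, out_features)  # stands in for a frozen layer
A = torch.randn(in_features, rank)            # LoRA factor, (in, rank)
B = torch.randn(rank, out_features)           # LoRA factor, (rank, out)

x = torch.randn(5, in_features)

with torch.no_grad():
    # Unmerged: frozen layer output plus the scaled LoRA branch.
    unmerged = base(x) + alpha * (x @ (A @ B))

    # Merged: nn.Linear stores its weight as (out, in), so add the transpose.
    merged_layer = nn.Linear(in_features, out_features)
    merged_layer.weight.copy_(base.weight + alpha * (A @ B).T)
    merged_layer.bias.copy_(base.bias)
    merged = merged_layer(x)

print(torch.allclose(unmerged, merged, atol=1e-5))
```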
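<p>Here, we set <code class="language-plaintext highlighter-rouge">rank</code> to 32 and <code class="language-plaintext highlighter-rouge">alpha</code> to 1, and train the model for 5 additional epochs, just like the baseline. With this LoRA configuration, we train about 33.6K parameters, roughly 61% of the 55.1K parameters updated by full finetuning. The savings are modest here because a rank of 32 is large relative to the 64-dimensional hidden layers. A rank-<em>r</em> update to a layer with <em>m</em> inputs and <em>n</em> outputs costs <em>r</em>(<em>m</em> + <em>n</em>) parameters instead of <em>mn</em>, so the gap widens as layers grow relative to the rank.</p>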
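<p>The parameter counts can be checked directly. The sketch below uses <code class="language-plaintext highlighter-rouge">TinyMLP</code>, a hypothetical plain-PyTorch stand-in for <code class="language-plaintext highlighter-rouge">MNISTModel</code>. Its layer sizes (784, 64, 64, and 10 units) are inferred from the model summary. The sketch freezes the model, attaches rank-32 factors shaped like those in <code class="language-plaintext highlighter-rouge">MNISTLoRAModel</code>, and counts parameters by <code class="language-plaintext highlighter-rouge">requires_grad</code>. The helper <code class="language-plaintext highlighter-rouge">lora_factors</code> is also hypothetical; only the factor shapes matter here.</p>

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    # Hypothetical stand-in for MNISTModel: the three linear layers only.
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, 10)


def lora_factors(in_features: int, out_features: int, rank: int):
    # Hypothetical helper: an (in, rank) factor and a (rank, out) factor.
    return nn.ParameterList([
        nn.Parameter(torch.zeros(in_features, rank)),
        nn.Parameter(torch.zeros(rank, out_features)),
    ])


model = TinyMLP()
for p in model.parameters():
    p.requires_grad = False  # freeze the "pretrained" weights

rank = 32
model.l1_lora = lora_factors(28 * 28, 64, rank)
model.l2_lora = lora_factors(64, 64, rank)
model.l3_lora = lora_factors(64, 10, rank)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(trainable, frozen, round(trainable / frozen, 2))
```

<p>The counts of 33,600 trainable and 55,050 frozen parameters agree with the 33.6 K and 55.1 K figures in the Lightning model summary below.</p>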
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lora_model</span> <span class="o">=</span> <span class="n">MNISTLoRAModel</span><span class="p">(</span><span class="n">rank</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"model.pt"</span><span class="p">)</span>
<span class="n">lora_model</span><span class="p">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">strict</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">datamodule</span> <span class="o">=</span> <span class="n">MNISTDataModule</span><span class="p">()</span>
<span class="n">lora_trainer</span> <span class="o">=</span> <span class="n">pl</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span>
    <span class="n">accelerator</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span>
    <span class="n">devices</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
    <span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">logger</span><span class="o">=</span><span class="n">pl</span><span class="p">.</span><span class="n">loggers</span><span class="p">.</span><span class="n">CSVLogger</span><span class="p">(</span><span class="s">"logs"</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">lora_trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">lora_model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
<span class="n">read_metrics</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">lora_trainer</span><span class="p">.</span><span class="n">logger</span><span class="p">.</span><span class="n">log_dir</span><span class="si">}</span><span class="s">/metrics.csv"</span><span class="p">)</span>
<span class="n">lora_trainer</span><span class="p">.</span><span class="n">test</span><span class="p">(</span><span class="n">lora_model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
| Name | Type | Params
-----------------------------------------------------
0 | l1 | Linear | 50.2 K
1 | l2 | Linear | 4.2 K
2 | l3 | Linear | 650
3 | val_accuracy | MulticlassAccuracy | 0
4 | test_accuracy | MulticlassAccuracy | 0
5 | l1_lora | LoRALinear | 27.1 K
6 | l2_lora | LoRALinear | 4.1 K
7 | l3_lora | LoRALinear | 2.4 K
-----------------------------------------------------
33.6 K Trainable params
55.1 K Non-trainable params
88.7 K Total params
0.355 Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
| Name | Type | Params
-----------------------------------------------------
0 | l1 | Linear | 50.2 K
1 | l2 | Linear | 4.2 K
2 | l3 | Linear | 650
3 | val_accuracy | MulticlassAccuracy | 0
4 | test_accuracy | MulticlassAccuracy | 0
5 | l1_lora | LoRALinear | 27.1 K
6 | l2_lora | LoRALinear | 4.1 K
7 | l3_lora | LoRALinear | 2.4 K
-----------------------------------------------------
33.6 K Trainable params
55.1 K Non-trainable params
88.7 K Total params
0.355 Total estimated model params size (MB)
INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
</code></pre></div></div>
<div id="df-98550d14-c734-47d5-9dd6-cb4abd042838" class="colab-df-container">
<div>
<style scoped="">
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>train_loss</th>
<th>val_loss</th>
<th>val_acc</th>
</tr>
<tr>
<th>epoch</th>
<th></th>
<th></th>
<th></th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>0.907201</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>0</th>
<td>NaN</td>
<td>0.928417</td>
<td>0.7052</td>
</tr>
<tr>
<th>1</th>
<td>0.904668</td>
<td>NaN</td>
<td>NaN</td>
</tr>
<tr>
<th>1</th>
<td>NaN</td>
<td>0.876187</td>
<td>0.7232</td>
</tr>
<tr>
<th>2</th>
<td>0.812540</td>
<td>NaN</td>
<td>NaN</td>
</tr>
</tbody>
</table>
</div>
<div class="colab-df-buttons">
<div class="colab-df-container">
<button class="colab-df-convert" onclick="convertToInteractive('df-98550d14-c734-47d5-9dd6-cb4abd042838')" title="Convert this dataframe to an interactive table." style="display:none;"> <svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960"> <path d="M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z"></path> </svg>
</button> <style>
.colab-df-container {
display:flex;
gap: 12px;
}
.colab-df-convert {
background-color: #E8F0FE;
border: none;
border-radius: 50%;
cursor: pointer;
display: none;
fill: #1967D2;
height: 32px;
padding: 0 0 0 0;
width: 32px;
}
.colab-df-convert:hover {
background-color: #E2EBFA;
box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);
fill: #174EA6;
}
.colab-df-buttons div {
margin-bottom: 4px;
}
[theme=dark] .colab-df-convert {
background-color: #3B4455;
fill: #D2E3FC;
}
[theme=dark] .colab-df-convert:hover {
background-color: #434B5C;
box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);
filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));
fill: #FFFFFF;
}
</style>
<script>
const buttonEl =
document.querySelector('#df-98550d14-c734-47d5-9dd6-cb4abd042838 button.colab-df-convert');
buttonEl.style.display =
google.colab.kernel.accessAllowed ? 'block' : 'none';
async function convertToInteractive(key) {
const element = document.querySelector('#df-98550d14-c734-47d5-9dd6-cb4abd042838');
const dataTable =
await google.colab.kernel.invokeFunction('convertToInteractive',
[key], {});
if (!dataTable) return;
const docLinkHtml = 'Like what you see? Visit the ' +
'<a target="_blank" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'
+ ' to learn more about interactive tables.';
element.innerHTML = '';
dataTable['output_type'] = 'display_data';
await google.colab.output.renderOutput(dataTable, element);
const docLink = document.createElement('div');
docLink.innerHTML = docLinkHtml;
element.appendChild(docLink);
}
</script>
</div>
<div id="df-f8c857dc-1d85-4aa2-827a-1d2b431ea509">
<button class="colab-df-quickchart" onclick="quickchart('df-f8c857dc-1d85-4aa2-827a-1d2b431ea509')" title="Suggest charts." style="display:none;"> <svg xmlns="http://www.w3.org/2000/svg" height="24px"viewBox="0 0 24 24" width="24px"> <g> <path d="M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z"></path> </g>
</svg> </button> <style>
.colab-df-quickchart {
background-color: #E8F0FE;
border: none;
border-radius: 50%;
cursor: pointer;
display: none;
fill: #1967D2;
height: 32px;
padding: 0 0 0 0;
width: 32px;
}
.colab-df-quickchart:hover {
background-color: #E2EBFA;
box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);
fill: #174EA6;
}
[theme=dark] .colab-df-quickchart {
background-color: #3B4455;
fill: #D2E3FC;
}
[theme=dark] .colab-df-quickchart:hover {
background-color: #434B5C;
box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);
filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));
fill: #FFFFFF;
}
</style>
<script>
async function quickchart(key) {
const charts = await google.colab.kernel.invokeFunction(
'suggestCharts', [key], {});
}
(() => {
let quickchartButtonEl =
document.querySelector('#df-f8c857dc-1d85-4aa2-827a-1d2b431ea509 button');
quickchartButtonEl.style.display =
google.colab.kernel.accessAllowed ? 'block' : 'none';
})();
</script>
</div>
</div>
</div>
<p><img src="/assets/images/2023-08-20-lora_files/2023-08-20-lora_25_10.png" /></p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
</code></pre></div></div>
<div style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace" class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ test_acc │ 0.7870000004768372 │
│ test_loss │ 0.6841080188751221 │
└───────────────────────────┴───────────────────────────┘
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[{'test_loss': 0.6841080188751221, 'test_acc': 0.7870000004768372}]
</code></pre></div></div>
<p>The test accuracy is about 0.79, only about 0.02 below the roughly 0.81 achieved by the full finetuning baseline, even though LoRA trains just 33.6 K parameters.</p>
<p>One would expect LoRA to match the baseline more closely as the rank grows. Let’s verify this hypothesis by repeating the experiment with ranks 1, 2, 4, 8, 16, 32, and 64. Logs from Lightning are omitted.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_lora</span><span class="p">(</span><span class="n">rank</span><span class="p">):</span>
<span class="n">lora_model</span> <span class="o">=</span> <span class="n">MNISTLoRAModel</span><span class="p">(</span><span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">state_dict</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="s">"model.pt"</span><span class="p">)</span>
<span class="n">lora_model</span><span class="p">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">state_dict</span><span class="p">,</span> <span class="n">strict</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="n">datamodule</span> <span class="o">=</span> <span class="n">MNISTDataModule</span><span class="p">()</span>
<span class="n">lora_trainer</span> <span class="o">=</span> <span class="n">pl</span><span class="p">.</span><span class="n">Trainer</span><span class="p">(</span>
<span class="n">accelerator</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span>
<span class="n">devices</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">max_epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">lora_trainer</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">lora_model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)</span>
<span class="k">return</span> <span class="n">lora_trainer</span><span class="p">.</span><span class="n">test</span><span class="p">(</span><span class="n">lora_model</span><span class="p">,</span> <span class="n">datamodule</span><span class="o">=</span><span class="n">datamodule</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="s">"test_acc"</span><span class="p">]</span>
<span class="n">ranks</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">64</span><span class="p">]</span>
<span class="n">test_accs</span> <span class="o">=</span> <span class="p">[</span><span class="n">run_lora</span><span class="p">(</span><span class="n">rank</span><span class="p">)</span> <span class="k">for</span> <span class="n">rank</span> <span class="ow">in</span> <span class="n">ranks</span><span class="p">]</span>
</code></pre></div></div>
<p>Plotting the results, we see that LoRA indeed reaches the baseline when we give it the full rank of 64. In fact, at rank 64 LoRA outperforms the baseline, which is less surprising than it sounds: each adapted layer now trains two full-rank matrices instead of one, so the number of trainable parameters exceeds that of the baseline. More generally, as the rank increases, LoRA’s performance more closely matches that of the baseline.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">ranks</span><span class="p">,</span> <span class="n">test_accs</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">axhline</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="mf">0.809499979019165</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">"r"</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">"--"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"Rank"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"Accuracy"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="/assets/images/2023-08-20-lora_files/2023-08-20-lora_29_0.png" /></p>
<h1 id="conclusion">Conclusion</h1>
<p>In this post, we explored LoRA, a parameter-efficient finetuning method. The beauty of LoRA is in its simplicity: it is motivated by a simple heuristic, and it is relatively straightforward to implement in practice. LoRA has been applied to a variety of architectures, including LLMs and <a href="https://huggingface.co/blog/lora">Stable Diffusion</a>. This is in part because LoRA has been battle-tested primarily on self-attention modules, which both LLMs and text-to-image models use. Through this experiment, we also saw that rank is an important hyperparameter that trades off model performance against computational cost: the higher the rank, the more trainable parameters.</p>
<p>LoRA was further improved and explored in follow-up papers such as <a href="https://arxiv.org/abs/2305.14314">QLoRA: Efficient Finetuning of Quantized LLMs</a> by Dettmers et al. These developments have contributed greatly to the democratization of LLMs: people can now consider finetuning LMs on consumer-grade GPUs from the comfort of their homes. This year, we have also seen an explosion of Llama variants, too many to name, which has invigorated the open-source community to reproduce, match, and sometimes even outperform closed models like GPT-3.5 or GPT-4. This is an exciting development, and I look forward to the breakthroughs to come.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>I’m still not sure about the “right” way of capitalizing Llama. In the <a href="https://arxiv.org/abs/2302.13971">original paper</a>, the model was written as “LLaMA.” However, in the <a href="https://arxiv.org/abs/2307.09288">most recent paper</a>, the same authors opted for a simplified convention, “Llama.” I’m going with the second version since it is simpler and more recent. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Jake TaeHacking Word Hunt2022-08-21T00:00:00+00:002022-08-21T00:00:00+00:00https://jaketae.github.io/study/wordhunt<blockquote>
<p>Update: The code was modified with further optimizations. In particular, instead of querying the trie from the root on every DFS call, we pass the current trie node along the DFS so that each step only needs to look up a single child.</p>
</blockquote>
<p>Recently, I started playing <a href="https://apps.apple.com/us/app/gamepigeon/id1124197642">GamePigeon</a> games with my girlfriend. We often play Word Hunt, where the objective is to find as many words as possible in a grid of English letters within 30 seconds.</p>
<p><img src="https://i.stack.imgur.com/JsxLT.jpg" alt="img" /></p>
<p>Being a non-native English speaker, I seldom score a win against my girlfriend; she often claims victory by significant margins. In a desperate attempt to level the playing field, and also inspired by a <a href="https://www.youtube.com/watch?v=sMDcdDczXDc">YouTube video on Word Hunt</a>, I decided to resort to computers and algorithms.</p>
<h1 id="brute-force-dfs">Brute Force DFS</h1>
<p>The goal of this project is to find as many valid words as possible in a given grid of letters. Since the game awards more points for longer words, the longer the words, the better. Most importantly, we need to find these solutions within 30 seconds.</p>
<p>A naïve brute-force approach would be to traverse the grid to recover all possible sequences of letters, then check whether each sequence appears in a source-of-truth vocabulary list. Concretely, we can explore the grid with any graph traversal algorithm like DFS and store all English words in a Python set for amortized $O(1)$ lookup. Unfortunately, after a few iterations, I realized that this brute-force approach is too slow for the 30-second time limit: the number of letter paths on the board grows combinatorially with path length, and most of them spell nothing.</p>
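<p>For reference, here is a minimal sketch of that brute-force idea: enumerate every path of distinct, adjacent cells with DFS, then keep the strings that appear in a Python set of words. The helper names (<code class="language-plaintext highlighter-rouge">neighbors</code>, <code class="language-plaintext highlighter-rouge">all_sequences</code>, <code class="language-plaintext highlighter-rouge">brute_force</code>) and the tiny word list are illustrative assumptions, not code from this project. Because every path is generated whether or not it can still become a word, the work explodes as the board grows.</p>

```python
from typing import Iterator, List, Set, Tuple


def neighbors(i: int, j: int, n: int) -> Iterator[Tuple[int, int]]:
    """Yield the in-bounds cells among the eight surrounding (i, j)."""
    for di in (-1, 0, 1):
        for dj in (-1, 0, 1):
            if (di, dj) != (0, 0) and 0 <= i + di < n and 0 <= j + dj < n:
                yield i + di, j + dj


def all_sequences(grid: List[List[str]]) -> Iterator[str]:
    """Yield the string spelled by every path of distinct, adjacent cells."""
    n = len(grid)

    def dfs(i: int, j: int, prefix: str, visited: Set[Tuple[int, int]]) -> Iterator[str]:
        prefix += grid[i][j]
        yield prefix
        visited.add((i, j))
        for ni, nj in neighbors(i, j, n):
            if (ni, nj) not in visited:
                yield from dfs(ni, nj, prefix, visited)
        visited.remove((i, j))  # backtrack so other paths may reuse this cell

    for i in range(n):
        for j in range(n):
            yield from dfs(i, j, "", set())


def brute_force(grid: List[List[str]], vocab: Set[str]) -> Set[str]:
    # Set membership is amortized O(1), but every path is still generated.
    return {s for s in all_sequences(grid) if s in vocab}


if __name__ == "__main__":
    grid = [["c", "a"], ["t", "s"]]
    print(sum(1 for _ in all_sequences(grid)))  # 64 paths on a 2x2 board
    print(sorted(brute_force(grid, {"cat", "cats", "act", "sat", "dog"})))
    # ['act', 'cat', 'cats', 'sat']
```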
<h1 id="dfs-with-pruning-via-trie-lookup">DFS with Pruning via Trie Lookup</h1>
<p>One glaring inefficiency with the above approach is that we end up wastefully exploring infelicitous paths, i.e., paths that we already know cannot lead to a solution. For instance, if we know ahead of time that no word starts with the prefix “xyz”, then there is no point in exploring “xyza” or “xyzb.” Instead, we can terminate the search and move on to paths where there is hope.</p>
<p>Unfortunately, the built-in Python set does not support prefix lookup. A more suitable data structure is a <a href="https://en.wikipedia.org/wiki/Trie">trie</a>, also known as a prefix tree. A trie not only looks up a word in time proportional to its length, but also lets us efficiently check whether any word starts with a given prefix. If no word starts with the current prefix, we abandon the path, which effectively amounts to DFS backtracking with pruning.</p>
<h2 id="trie">Trie</h2>
<p>Python does not provide a built-in trie implementation. Although <a href="https://pypi.org/project/trie/">third-party packages</a> exist, I decided to implement my own.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Trie</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="bp">None</span><span class="p">:</span>
<span class="bp">self</span><span class="p">.</span><span class="n">root</span> <span class="o">=</span> <span class="p">{}</span>
<span class="bp">self</span><span class="p">.</span><span class="n">delimiter</span> <span class="o">=</span> <span class="s">"*"</span>
<span class="k">def</span> <span class="nf">insert</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">word</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="bp">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="p">.</span><span class="n">contains</span><span class="p">(</span><span class="n">word</span><span class="p">):</span>
<span class="k">return</span>
<span class="n">pointer</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">root</span>
<span class="n">word</span> <span class="o">+=</span> <span class="bp">self</span><span class="p">.</span><span class="n">delimiter</span>
<span class="k">for</span> <span class="n">char</span> <span class="ow">in</span> <span class="n">word</span><span class="p">:</span>
<span class="k">if</span> <span class="n">char</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">pointer</span><span class="p">:</span>
<span class="n">pointer</span><span class="p">[</span><span class="n">char</span><span class="p">]</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">pointer</span> <span class="o">=</span> <span class="n">pointer</span><span class="p">[</span><span class="n">letter</span><span class="p">]</span>
</code></pre></div></div>
<p>Internally, this trie implementation uses a nested dictionary to store words as a sequence of letters. We use an asterisk to mark the end of a word. For instance, adding the word “cat” to an empty trie will yield the following result:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="kn">from</span> <span class="nn">trie</span> <span class="kn">import</span> <span class="n">Trie</span>
<span class="o">>>></span> <span class="n">t</span> <span class="o">=</span> <span class="n">Trie</span><span class="p">()</span>
<span class="o">>>></span> <span class="n">t</span><span class="p">.</span><span class="n">insert</span><span class="p">(</span><span class="s">"cat"</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">t</span><span class="p">.</span><span class="n">trie</span>
<span class="p">{</span><span class="s">'c'</span><span class="p">:</span> <span class="p">{</span><span class="s">'a'</span><span class="p">:</span> <span class="p">{</span><span class="s">'t'</span><span class="p">:</span> <span class="p">{</span><span class="s">'*'</span><span class="p">:</span> <span class="p">{}}}}}</span>
</code></pre></div></div>
<p>Once we insert “car”, the shared “ca” prefix is reused, and only a new “r” branch is added under “a”.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">t</span><span class="p">.</span><span class="n">insert</span><span class="p">(</span><span class="s">"car"</span><span class="p">)</span>
<span class="o">>>></span> <span class="n">t</span><span class="p">.</span><span class="n">trie</span>
<span class="p">{</span><span class="s">'c'</span><span class="p">:</span> <span class="p">{</span><span class="s">'a'</span><span class="p">:</span> <span class="p">{</span><span class="s">'t'</span><span class="p">:</span> <span class="p">{</span><span class="s">'*'</span><span class="p">:</span> <span class="p">{}},</span> <span class="s">'r'</span><span class="p">:</span> <span class="p">{</span><span class="s">'*'</span><span class="p">:</span> <span class="p">{}}}}}</span>
</code></pre></div></div>
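<p>The nested-dictionary layout above is also what makes prefix queries cheap: checking whether any word starts with a prefix is just a walk down the dictionaries, one letter at a time. Below is a minimal, standalone sketch; the function names <code class="language-plaintext highlighter-rouge">build_trie</code> and <code class="language-plaintext highlighter-rouge">has_prefix</code> are hypothetical helpers for illustration and are not part of the <code class="language-plaintext highlighter-rouge">Trie</code> class in this post.</p>

```python
from typing import Any, Dict, Iterable

DELIMITER = "*"  # end-of-word marker, mirroring the Trie class above


def build_trie(words: Iterable[str]) -> Dict[str, Any]:
    """Insert each word as a path of nested dicts ending in the delimiter."""
    root: Dict[str, Any] = {}
    for word in words:
        node = root
        for char in word + DELIMITER:
            node = node.setdefault(char, {})
    return root


def has_prefix(root: Dict[str, Any], prefix: str) -> bool:
    """Return True if at least one stored word starts with `prefix`."""
    node = root
    for char in prefix:
        if char not in node:
            return False  # no stored word continues this way: prune here
        node = node[char]
    return True


if __name__ == "__main__":
    trie = build_trie(["cat", "car"])
    print(trie)  # {'c': {'a': {'t': {'*': {}}, 'r': {'*': {}}}}}
    print(has_prefix(trie, "ca"), has_prefix(trie, "xyz"))  # True False
```

<p>The cost of <code class="language-plaintext highlighter-rouge">has_prefix</code> depends only on the length of the prefix, not on the size of the vocabulary, which is exactly the property the DFS needs to prune dead-end paths.</p>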
<p>Now that we have a trie, we can store the list of English words in it. We simply read a text file containing one word per line and insert each word into the trie.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">get_dictionary</span><span class="p">()</span> <span class="o">-></span> <span class="n">Trie</span><span class="p">:</span>
<span class="n">dictionary</span> <span class="o">=</span> <span class="n">Trie</span><span class="p">()</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s">"dictionary.txt"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
<span class="k">for</span> <span class="n">word</span> <span class="ow">in</span> <span class="n">f</span><span class="p">:</span>
<span class="n">word</span> <span class="o">=</span> <span class="n">word</span><span class="p">.</span><span class="n">strip</span><span class="p">()</span>
<span class="n">dictionary</span><span class="p">.</span><span class="n">insert</span><span class="p">(</span><span class="n">word</span><span class="p">)</span>
<span class="k">return</span> <span class="n">dictionary</span>
</code></pre></div></div>
<h2 id="solving-word-hunt">Solving Word Hunt</h2>
<p>Now that the trie dictionary is ready, the next step is to traverse the board and retrieve all valid solutions. I took inspiration from DFS backtracking templates used to solve common problems such as sudoku. For each cell in the game grid, we want to find valid words that start at that cell. The <code class="language-plaintext highlighter-rouge">solve(grid)</code> function accepts a grid and calls the <code class="language-plaintext highlighter-rouge">traverse(...)</code> function to search for words starting at each cell.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">List</span><span class="p">,</span> <span class="n">Tuple</span>
<span class="k">def</span> <span class="nf">solve</span><span class="p">(</span><span class="n">grid</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]])</span> <span class="o">-></span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]]:</span>
<span class="n">solutions</span> <span class="o">=</span> <span class="p">{}</span>
<span class="n">dictionary</span> <span class="o">=</span> <span class="n">get_dictionary</span><span class="p">()</span>
<span class="c1"># BOARD_SIZE == 4
</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">BOARD_SIZE</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">BOARD_SIZE</span><span class="p">):</span>
<span class="k">if</span> <span class="n">board</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="ow">in</span> <span class="n">dictionary</span><span class="p">.</span><span class="n">root</span><span class="p">:</span>
<span class="n">traverse</span><span class="p">(</span><span class="n">grid</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="s">""</span><span class="p">,</span> <span class="p">[],</span> <span class="n">solutions</span><span class="p">,</span> <span class="n">dictionary</span><span class="p">.</span><span class="n">root</span><span class="p">)</span>
<span class="k">return</span> <span class="n">solutions</span>
</code></pre></div></div>
<p>Although the function is named <code class="language-plaintext highlighter-rouge">solve(...)</code>, the actual heavy lifting is performed by the <code class="language-plaintext highlighter-rouge">traverse(...)</code> function, which recursively calls itself to perform DFS. Specifically, the <code class="language-plaintext highlighter-rouge">traverse(...)</code> function populates the <code class="language-plaintext highlighter-rouge">solutions</code> dictionary, which will contain valid words as keys and index sequences as values.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="nn">collections.abc</span> <span class="kn">import</span> <span class="n">Generator</span>
<span class="k">def</span> <span class="nf">get_neighbors</span><span class="p">(</span><span class="n">i</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">j</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-></span> <span class="n">Generator</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]:</span>
<span class="k">for</span> <span class="n">delta_i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">for</span> <span class="n">delta_j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">):</span>
<span class="k">if</span> <span class="n">delta_i</span> <span class="o">==</span> <span class="n">delta_j</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">continue</span>
<span class="n">next_i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="n">delta_i</span>
<span class="n">next_j</span> <span class="o">=</span> <span class="n">j</span> <span class="o">+</span> <span class="n">delta_j</span>
<span class="k">if</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">next_i</span> <span class="o"><</span> <span class="n">BOARD_SIZE</span> <span class="ow">and</span> <span class="mi">0</span> <span class="o"><=</span> <span class="n">next_j</span> <span class="o"><</span> <span class="n">BOARD_SIZE</span><span class="p">:</span>
<span class="k">yield</span> <span class="p">(</span><span class="n">next_i</span><span class="p">,</span> <span class="n">next_j</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">traverse</span><span class="p">(</span>
<span class="n">grid</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]],</span>
<span class="n">i</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">j</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">word</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">order</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]],</span>
<span class="n">solutions</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">]]],</span>
<span class="n">pointer</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">:</span> <span class="n">Any</span><span class="p">],</span>
<span class="p">)</span> <span class="o">-></span> <span class="bp">None</span><span class="p">:</span>
<span class="n">char</span> <span class="o">=</span> <span class="n">grid</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span>
<span class="n">word</span> <span class="o">+=</span> <span class="n">char</span>
<span class="n">order</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">))</span>
<span class="n">prev</span> <span class="o">=</span> <span class="n">pointer</span>
<span class="n">pointer</span> <span class="o">=</span> <span class="n">pointer</span><span class="p">[</span><span class="n">char</span><span class="p">]</span>
<span class="k">if</span> <span class="s">"*"</span> <span class="ow">in</span> <span class="n">pointer</span><span class="p">:</span>
<span class="n">solutions</span><span class="p">[</span><span class="n">word</span><span class="p">]</span> <span class="o">=</span> <span class="n">order</span>
<span class="k">del</span> <span class="n">pointer</span><span class="p">[</span><span class="s">"*"</span><span class="p">]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">pointer</span><span class="p">:</span>
<span class="k">del</span> <span class="n">prev</span><span class="p">[</span><span class="n">char</span><span class="p">]</span>
<span class="k">return</span>
<span class="n">grid</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="bp">None</span>
<span class="k">for</span> <span class="n">next_i</span><span class="p">,</span> <span class="n">next_j</span> <span class="ow">in</span> <span class="n">get_neighbors</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">):</span>
<span class="k">if</span> <span class="p">(</span>
<span class="n">grid</span><span class="p">[</span><span class="n">next_i</span><span class="p">][</span><span class="n">next_j</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span>
<span class="ow">and</span> <span class="n">grid</span><span class="p">[</span><span class="n">next_i</span><span class="p">][</span><span class="n">next_j</span><span class="p">]</span> <span class="ow">in</span> <span class="n">pointer</span>
<span class="p">):</span>
<span class="n">traverse</span><span class="p">(</span><span class="n">grid</span><span class="p">,</span> <span class="n">next_i</span><span class="p">,</span> <span class="n">next_j</span><span class="p">,</span> <span class="n">word</span><span class="p">,</span> <span class="n">order</span><span class="p">.</span><span class="n">copy</span><span class="p">(),</span> <span class="n">solutions</span><span class="p">,</span> <span class="n">pointer</span><span class="p">)</span>
<span class="n">grid</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">char</span>
</code></pre></div></div>
<p>To prevent the algorithm from revisiting cells (it’s illegal to reuse a letter that is already part of the current sequence), we mark the current cell as <code class="language-plaintext highlighter-rouge">None</code> and recursively call <code class="language-plaintext highlighter-rouge">traverse(...)</code> on the neighboring cells returned by <code class="language-plaintext highlighter-rouge">get_neighbors(i, j)</code>. Once all paths through the cell have been explored, we restore the cell to its original character. This marking and unmarking is at the heart of backtracking. Notice that the function also has an implicit base case: when no unvisited neighbor can extend the current prefix, the loop simply makes no further recursive calls.</p>
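<p><code class="language-plaintext highlighter-rouge">get_neighbors(i, j)</code> is not shown in this section. A minimal sketch, assuming a 4×4 board (<code class="language-plaintext highlighter-rouge">BOARD_SIZE = 4</code>) and Word Hunt’s rule that a letter may connect to any of its eight surrounding cells, diagonals included, might look like this:</p>

```python
from typing import List, Tuple

BOARD_SIZE = 4  # assumed: a 4x4 Word Hunt board (16 letters)


def get_neighbors(i: int, j: int) -> List[Tuple[int, int]]:
    """Return the in-bounds coordinates of the (up to) 8 cells adjacent to (i, j)."""
    neighbors = []
    for di in (-1, 0, 1):
        for dj in (-1, 0, 1):
            if di == 0 and dj == 0:
                continue  # skip the cell itself
            ni, nj = i + di, j + dj
            if 0 <= ni < BOARD_SIZE and 0 <= nj < BOARD_SIZE:
                neighbors.append((ni, nj))
    return neighbors


# corner, edge, and interior cells
print(len(get_neighbors(0, 0)), len(get_neighbors(0, 1)), len(get_neighbors(1, 1)))  # prints: 3 5 8
```

<p>Filtering out already-visited cells is left to <code class="language-plaintext highlighter-rouge">traverse</code>, which checks for <code class="language-plaintext highlighter-rouge">None</code> before recursing.</p>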
<p>Also worthy of note is the use of the <code class="language-plaintext highlighter-rouge">dictionary</code> trie, which is what enables pruning: we only recurse into a neighbor whose character continues the current trie node, and the <code class="language-plaintext highlighter-rouge">return</code> in the middle of the function stops the search once that node has nothing left below it. In other words, if no word starts with <code class="language-plaintext highlighter-rouge">word</code> as its prefix, there is no need to venture further down this path. Moreover, if <code class="language-plaintext highlighter-rouge">word</code> itself is in the vocabulary (its trie node contains the <code class="language-plaintext highlighter-rouge">"*"</code> end marker), we add it to <code class="language-plaintext highlighter-rouge">solutions</code>. Note that multiple paths may exist for the same word, but since any one valid path suffices, there is no need to record all of them.</p>
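<p>How the <code class="language-plaintext highlighter-rouge">dictionary</code> trie is built is not shown here. As a hedged sketch, assuming a nested-<code class="language-plaintext highlighter-rouge">dict</code> layout in which each key is a character and the special key <code class="language-plaintext highlighter-rouge">"*"</code> marks the end of a word (the marker <code class="language-plaintext highlighter-rouge">traverse</code> checks for), construction and a prefix query could look like this. <code class="language-plaintext highlighter-rouge">build_trie</code>, <code class="language-plaintext highlighter-rouge">has_prefix</code>, and the toy vocabulary are hypothetical.</p>

```python
from typing import Dict, Iterable


def build_trie(words: Iterable[str]) -> Dict:
    """Build a nested-dict trie; the key "*" marks the end of a word."""
    root: Dict = {}
    for word in words:
        node = root
        for char in word:
            node = node.setdefault(char, {})  # descend, creating nodes as needed
        node["*"] = True  # end-of-word marker
    return root


def has_prefix(trie: Dict, prefix: str) -> bool:
    """Return True if some word in the trie starts with `prefix`."""
    node = trie
    for char in prefix:
        if char not in node:
            return False  # prune: no word continues this path
        node = node[char]
    return True


trie = build_trie(["tap", "tape", "tan"])
print(has_prefix(trie, "ta"), has_prefix(trie, "tx"))  # True False
print("*" in trie["t"]["a"]["p"])  # True: "tap" is itself a word
```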
<h2 id="putting-it-all-together">Putting It All Together</h2>
<p>Now that we have all the core algorithms ready, all we need is a surface-level API that allows the user to interact with these functions. Although a GUI would be nice, for the sake of simplicity I decided to make this a Python script. I also decided that the easiest way for a user to input the grid to the script is in <a href="https://en.wikipedia.org/wiki/Raster_scan">raster scan</a> order, which is a fancy way of saying left to right, top to bottom. Therefore, the 4×4 grid is flattened into a line of 16 characters. Internally, we still want to parse the board as a grid: hence the <code class="language-plaintext highlighter-rouge">make_grid(board)</code> function, where <code class="language-plaintext highlighter-rouge">board</code> is the line of 16 characters entered by the user.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">make_grid</span><span class="p">(</span><span class="n">board</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]]:</span>
<span class="n">grid</span> <span class="o">=</span> <span class="p">[[]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">BOARD_SIZE</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">char</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">board</span><span class="p">):</span>
<span class="n">grid</span><span class="p">[</span><span class="n">i</span> <span class="o">//</span> <span class="n">BOARD_SIZE</span><span class="p">].</span><span class="n">append</span><span class="p">(</span><span class="n">char</span><span class="p">)</span>
<span class="k">return</span> <span class="n">grid</span>
</code></pre></div></div>
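<p>To make the raster-scan convention concrete, here is a self-contained copy of <code class="language-plaintext highlighter-rouge">make_grid</code> (with <code class="language-plaintext highlighter-rouge">BOARD_SIZE = 4</code>, which a 16-character board implies) applied to the sample board used later in this post. Character <code class="language-plaintext highlighter-rouge">k</code> of the input lands at row <code class="language-plaintext highlighter-rouge">k // 4</code>, column <code class="language-plaintext highlighter-rouge">k % 4</code>.</p>

```python
from typing import List

BOARD_SIZE = 4  # a 16-character board implies a 4x4 grid


def make_grid(board: str) -> List[List[str]]:
    # Character k of the raster-scan string goes to row k // BOARD_SIZE.
    grid = [[] for _ in range(BOARD_SIZE)]
    for i, char in enumerate(board):
        grid[i // BOARD_SIZE].append(char)
    return grid


grid = make_grid("oatrihpshtnrenei")
for row in grid:
    print(" ".join(row))
# o a t r
# i h p s
# h t n r
# e n e i
```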
<p>Now we are truly done! All we need is to (1) create the grid, (2) call the <code class="language-plaintext highlighter-rouge">solve(grid)</code> function, and (3) sort the answers by word length, longest first, and print the top <code class="language-plaintext highlighter-rouge">SHOW_TOP_K</code> of them.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">main</span><span class="p">(</span><span class="n">board</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-></span> <span class="bp">None</span><span class="p">:</span>
<span class="n">grid</span> <span class="o">=</span> <span class="n">make_grid</span><span class="p">(</span><span class="n">board</span><span class="p">)</span>
<span class="n">solutions</span> <span class="o">=</span> <span class="n">solve</span><span class="p">(</span><span class="n">grid</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">word</span><span class="p">,</span> <span class="n">order</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span>
<span class="nb">sorted</span><span class="p">(</span><span class="n">solutions</span><span class="p">.</span><span class="n">items</span><span class="p">(),</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">reverse</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="n">SHOW_TOP_K</span><span class="p">:</span>
<span class="k">break</span>
<span class="k">print</span><span class="p">(</span><span class="n">word</span><span class="p">,</span> <span class="n">order</span><span class="p">)</span>
<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">"__main__"</span><span class="p">:</span>
<span class="n">board</span> <span class="o">=</span> <span class="nb">input</span><span class="p">()</span>
<span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">board</span><span class="p">)</span> <span class="o">==</span> <span class="mi">16</span>
<span class="n">main</span><span class="p">(</span><span class="n">board</span><span class="p">)</span>
</code></pre></div></div>
<p>Here is a sample top-10 result with the example board shown at the very beginning of this blog post.</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>jaketae:wordhunt $ python main.py
oatrihpshtnrenei
haptene [(1, 1), (0, 1), (1, 2), (2, 1), (3, 2), (3, 1), (3, 0)]
haptens [(1, 1), (0, 1), (1, 2), (2, 1), (3, 2), (2, 2), (1, 3)]
pterins [(1, 2), (2, 1), (3, 2), (2, 3), (3, 3), (2, 2), (1, 3)]
staithe [(1, 3), (0, 2), (0, 1), (1, 0), (2, 1), (2, 0), (3, 0)]
tenners [(2, 1), (3, 0), (3, 1), (2, 2), (3, 2), (2, 3), (1, 3)]
tapnet [(0, 2), (0, 1), (1, 2), (2, 2), (3, 2), (2, 1)]
hapten [(1, 1), (0, 1), (1, 2), (2, 1), (3, 2), (3, 1)]
pterin [(1, 2), (2, 1), (3, 2), (2, 3), (3, 3), (2, 2)]
staith [(1, 3), (0, 2), (0, 1), (1, 0), (2, 1), (2, 0)]
sprent [(1, 3), (1, 2), (2, 3), (3, 2), (3, 1), (2, 1)]
</code></pre></div></div>
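<p>Each reported path is a list of <code class="language-plaintext highlighter-rouge">(row, col)</code> coordinates into the grid. As a sanity check on that output format, the hypothetical helper below (not part of the original script) confirms that a path spells its word, never reuses a cell, and only steps between adjacent cells, diagonals included.</p>

```python
from typing import List, Tuple

BOARD = "oatrihpshtnrenei"  # sample board from the run above
BOARD_SIZE = 4


def is_valid_path(word: str, path: List[Tuple[int, int]]) -> bool:
    grid = [BOARD[r * BOARD_SIZE:(r + 1) * BOARD_SIZE] for r in range(BOARD_SIZE)]
    # The letters along the path must spell the word.
    if "".join(grid[i][j] for i, j in path) != word:
        return False
    # No cell may be used twice.
    if len(set(path)) != len(path):
        return False
    # Consecutive cells must be neighbors (Chebyshev distance 1, so diagonals count).
    return all(
        max(abs(i1 - i2), abs(j1 - j2)) == 1
        for (i1, j1), (i2, j2) in zip(path, path[1:])
    )


path = [(1, 1), (0, 1), (1, 2), (2, 1), (3, 2), (3, 1), (3, 0)]
print(is_valid_path("haptene", path))  # True
```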
<p>There is no way I would have come up with some of these words.</p>
<h1 id="conclusion">Conclusion</h1>
<p>Today, we have seen one very practical application of algorithms: beating your girlfriend in Word Hunt. While the real test is to use this script in a game against her, preliminary results appear promising.</p>
<p>I hope you enjoyed reading this post. See you in the next one!</p>Jake TaeUpdate: The code was modified with further optimizations. In particular, instead of checking the trie on every DFS call, we update the trie pointer along the DFS call so that the trie does not have to be queried repeatedly.Glow-TTS2022-04-11T00:00:00+00:002022-04-11T00:00:00+00:00https://jaketae.github.io/study/glowtts<p><em>Note: This blog post was completed as part of Yale’s CPSC 482: Current Topics in Applied Machine Learning.</em></p>
<p>“Turn right at 130 Prospect Street.”</p>
<p>If you’ve used Google Maps before, you will recall the familiar, smooth voice of the navigation assistant. At first glance, the voice appears to be a simple replay of human recordings. However, you will quickly realize that it is impossible to record the names of millions of streets, not to mention the billions of driving contexts in which they can appear.</p>
<p>Modern software, such as Google Maps or voice assistants, is powered by neural text-to-speech (TTS), a technology that synthesizes human-sounding voices using machine learning. In this blog post, we will dive deep into a NeurIPS 2020 paper, <a href="https://arxiv.org/abs/2005.11129">Glow-TTS: A Generative Flow for Text-to-Speech via Monotonic Alignment Search</a>, which demonstrates one of the many ways in which deep neural networks can be used for natural-sounding TTS.</p>
<h2 id="neural-text-to-speech">Neural Text-to-Speech</h2>
<p>Modern neural TTS pipelines are typically composed of two components: an acoustic feature generator and a vocoder. The acoustic feature generator accepts text as input and outputs an acoustic representation, such as a mel-spectrogram. The second stage of the pipeline, the neural vocoder, accepts this acoustic representation as input and outputs a raw waveform. More formally, let $f$ and $g$ denote the acoustic feature generator and the vocoder, respectively. Given an input text $T$, neural TTS can be understood as a composite function that outputs a waveform $W$ via</p>
\[\begin{aligned}
&X = f(T) \\
&W = g(X),
\end{aligned}\]
<p>where $X$ denotes the intermediate acoustic representation. Schematically, the composition $g \circ f$ fully defines the two-stage TTS process.</p>
<p>In this blog post, we will explore the first stage of the pipeline, the acoustic feature generator, exemplified by Glow-TTS. This post will proceed as follows. First, we discuss generative flow models, which are the first core component of Glow-TTS. Second, we discuss the monotonic alignment search (MAS) algorithm. Third, we discuss the Glow-TTS pipeline as a whole by putting flow and MAS into a single picture. Last but not least, we conclude by considering some of the limitations of Glow-TTS and refer to more recent literature that points to exciting directions in the field of neural TTS.</p>
<h2 id="flow">Flow</h2>
<p>Text-to-speech is a conditional generative task, in which a model is given a sequence of tokens and produces an utterance that matches the input text. Many neural TTS models employ generative models at their core, such as GANs, VAEs, transformers, or diffusion models, often borrowing from breakthroughs in other domains such as computer vision.</p>
<h3 id="change-of-variables">Change of Variables</h3>
<p>Glow-TTS is based on normalizing flow, which is a class of well-studied generative models. The theoretical basis of normalizing flows is the change of variables formula. Let $\mathbf{X}$ and $\mathbf{Y}$ denote random variables, each with PDF $f_\mathbf{X}$ and $f_\mathbf{Y}$, respectively. Let $h$ denote some invertible transformation such that $\mathbf{Y} = h(\mathbf{X})$. Typically, $f_\mathbf{X}$ is a simple, tractable prior distribution, such as a standard Gaussian, and we seek to apply $h$ to model some more complicated distribution given by $\mathbf{Y}$. Then, the change of variables formula states that</p>
\[\begin{aligned}
f_\mathbf{Y}(\mathbf{y})
&= f_\mathbf{X}(\mathbf{x}) \bigg| \det \frac{d \mathbf{x}}{d \mathbf{y}} \bigg| \\
&= f_\mathbf{X}(h^{-1}(\mathbf{y})) \bigg| \det \frac{d \mathbf{x}}{d \mathbf{y}} \bigg| \\
&= f_\mathbf{X}(h^{-1}(\mathbf{y})) \bigg| \det \frac{d h^{-1}(\mathbf{y})}{d \mathbf{y}} \bigg|,
\end{aligned}\]
<p>where $\det$ denotes the determinant and the derivative term represents the Jacobian.</p>
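<p>Since the Jacobian of an inverse function is the inverse of the original Jacobian, the same formula can also be written in terms of the forward transformation $h$. This form is convenient when we sample $\mathbf{x}$ from the base distribution and push it through $h$:</p>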
\[\begin{aligned}
f_\mathbf{Y}(\mathbf{y})
&= f_\mathbf{X}(\mathbf{x}) \bigg| \det \frac{d h^{-1}(\mathbf{y})}{d \mathbf{y}} \bigg| \\
&= f_\mathbf{X}(\mathbf{x}) \bigg| \det \left( \frac{d h(\mathbf{x})}{d \mathbf{x}} \right)^{-1} \bigg| \\
&= f_\mathbf{X}(\mathbf{x}) \bigg| \det \frac{d h(\mathbf{x})}{d \mathbf{x}} \bigg|^{-1}.
\end{aligned}\]
<p>The intuition behind the change of variables formula is that the probability mass of a small region in $\mathbf{X}$ space must be preserved when that region is mapped into $\mathbf{Y}$ space. The determinant of the Jacobian is a corrective term that accounts for how much $h$ locally stretches or compresses volume, i.e., the “sensitivity” of the transformation.</p>
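<p>A quick numerical check of the formula, assuming the simplest possible flow: a 1-D affine map $h(x) = ax + b$ applied to a standard Gaussian. Here $dh/dx = a$, so the formula predicts $f_Y(y) = f_X((y - b)/a) / |a|$, which should coincide with the known density of $\mathcal{N}(b, a^2)$. The values of $a$ and $b$ are arbitrary.</p>

```python
import math


def normal_pdf(x: float, mean: float = 0.0, std: float = 1.0) -> float:
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))


a, b = -2.0, 3.0  # affine flow h(x) = a * x + b (a != 0, so h is invertible)


def f_y_via_change_of_variables(y: float) -> float:
    x = (y - b) / a  # x = h^{-1}(y)
    return normal_pdf(x) / abs(a)  # divide by |det dh/dx| = |a|


for y in (-1.0, 0.5, 3.0, 7.0):
    assert math.isclose(f_y_via_change_of_variables(y), normal_pdf(y, mean=b, std=abs(a)))
print("change of variables matches N(b, a^2)")
```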
<h3 id="maximum-likelihood">Maximum Likelihood</h3>
<p>Normalizing flow models can then be understood as a composition of invertible transformations, i.e., $h_n \circ \cdots \circ h_2 \circ h_1$, where $n$ denotes the number of flow layers in the model.<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> To better understand what this composite transformation achieves, let’s apply a logarithm to the change of variables formula:</p>
\[\log f_\mathbf{Y} (\mathbf{y}) = \log f_\mathbf{X} (\mathbf{x}) - \log \bigg| \det \frac{d h(\mathbf{x})}{d \mathbf{x}} \bigg|.\]
<p>To simplify notation, let $f_i$ denote the PDF of the $i$-th random variable $\mathbf{x}_i = h_i(\mathbf{x}_{i - 1})$ in the composite transformation, with $\mathbf{x}_0$ drawn from the prior $f_0$. Then, the nested transformation can be expressed as</p>
\[\begin{aligned}
\log f_n(\mathbf{x}_n)
&= \log f_{n - 1}(\mathbf{x}_{n - 1}) - \log \bigg| \det \frac{d h_n(\mathbf{x}_{n - 1})}{d \mathbf{x}_{n - 1}} \bigg| \\
&= \log f_{n - 2}(\mathbf{x}_{n - 2}) - \log \bigg| \det \frac{d h_n(\mathbf{x}_{n - 1})}{d \mathbf{x}_{n - 1}} \bigg| - \log \bigg| \det \frac{d h_{n - 1}(\mathbf{x}_{n - 2})}{d \mathbf{x}_{n - 2}} \bigg| \\
&= \cdots \\
&= \log f_0(\mathbf{x}_0) - \sum_{i = 1}^n \log \bigg| \det \frac{d h_i(\mathbf{x}_{i - 1})}{d \mathbf{x}_{i - 1}} \bigg|.
\end{aligned}\]
<p>The immediate implication of this derivation is that repeatedly applying the change of variables formula provides a direct way of computing the likelihood of an observation from some complex, real-data distribution $f_n$, given a prior $f_0$ and a set of invertible transformations $h_1, h_2, \dots, h_n$. This illustrates the power of normalizing flows: they offer a direct way of measuring the likelihood of complex, high-dimensional data, such as ImageNet images, starting from a simple distribution, such as an isotropic Gaussian. Since the likelihood can be computed exactly, flow models are trained to maximize the log likelihood, which is exactly the expression derived above.</p>
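<p>To see the telescoping sum at work, here is a hedged sketch that stacks a few 1-D affine flows $h_i(x) = a_i x + b_i$ on top of a standard Gaussian prior $f_0$; the coefficients are arbitrary illustrative values, not anything from Glow-TTS. The log likelihood of $x_n$ is obtained by inverting the flows one at a time back to $x_0$ and subtracting each $\log |a_i|$. Because a composition of affine maps is itself affine, the result can be checked against a closed-form Gaussian.</p>

```python
import math

# Hypothetical flow parameters: h_i(x) = a_i * x + b_i, applied in order i = 1..n.
flows = [(2.0, 1.0), (-0.5, 0.3), (3.0, -2.0)]


def log_normal(x: float, mean: float = 0.0, std: float = 1.0) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def flow_log_likelihood(x_n: float) -> float:
    """log f_n(x_n) = log f_0(x_0) - sum_i log|dh_i/dx|, each evaluated at x_{i-1}."""
    x = x_n
    log_det_sum = 0.0
    for a, b in reversed(flows):  # invert h_n first, then h_{n-1}, ...
        x = (x - b) / a
        log_det_sum += math.log(abs(a))  # affine: the Jacobian is the constant a
    return log_normal(x) - log_det_sum


# The composition is affine, x_n = A * x_0 + B, so x_n ~ N(B, A^2).
A, B = 1.0, 0.0
for a, b in flows:
    A, B = a * A, a * B + b

for x_n in (-3.0, 0.0, 2.5):
    assert math.isclose(flow_log_likelihood(x_n), log_normal(x_n, mean=B, std=abs(A)))
```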
<h3 id="affine-coupling">Affine Coupling</h3>
<p>Although direct likelihood computation is a marked advantage of flow over other generative models, it comes with two clear constraints:</p>
<ul>
<li>All transformations must be invertible.</li>
<li>The determinant of the Jacobian must be easily computable.</li>
</ul>
<p>A number of methods have been proposed to satisfy these constraints. One of the most popular methods is the affine coupling layer. Let $d$ denote the dimensionality of the embedding space. Given an input $\mathbf{x}$ and an output $\mathbf{z}$, the affine coupling layer can schematically be written as</p>
\[\begin{aligned}
\mathbf{z}_{1:d/2} &= \mathbf{x}_{1:d/2} \\
\mathbf{z}_{d/2:d}
&= \mathbf{x}_{d/2:d} \odot s_\theta(\mathbf{x}_{1:d/2}) + t_\theta(\mathbf{x}_{1:d/2}) \\
&= \mathbf{x}_{d/2:d} \odot s_\theta(\mathbf{z}_{1:d/2}) + t_\theta(\mathbf{z}_{1:d/2}).
\end{aligned}\]
<p>In other words, the affine coupling layer implements a special transformation in which the top half of $\mathbf{z}$ is simply copied from $\mathbf{x}$ without modification. The bottom half undergoes an elementwise affine transformation, where the scale $s_\theta$ and shift $t_\theta$ are computed from the top half of $\mathbf{x}$. Since the top halves of $\mathbf{x}$ and $\mathbf{z}$ are identical, $s_\theta$ and $t_\theta$ can equally be evaluated on $\mathbf{z}_{1:d/2}$, which is what makes the layer easy to invert. We can easily check that this transformation is indeed invertible:</p>
\[\begin{aligned}
\mathbf{x}_{1:d/2} &= \mathbf{z}_{1:d/2} \\
\mathbf{x}_{d/2:d} &= \frac{\mathbf{z}_{d/2:d} - t_\theta(\mathbf{z}_{1:d/2})}{s_\theta(\mathbf{z}_{1:d/2})} \quad \text{(elementwise)}.
\end{aligned}\]
<p>Conveniently, the affine coupling layer is not only invertible, but it also enables efficient computation of the Jacobian determinant. Because the top half of the input passes through unchanged, $\mathbf{z}_{1:d/2}$ does not depend on $\mathbf{x}_{d/2:d}$, and the Jacobian is block lower triangular:</p>
\[\begin{aligned}
\mathbf{J}
&= \begin{pmatrix} \frac{d \mathbf{z}_{1:d/2}}{d \mathbf{x}_{1:d/2}} & \frac{d \mathbf{z}_{1:d/2}}{d \mathbf{x}_{d/2:d}} \\ \frac{d \mathbf{z}_{d/2:d}}{d \mathbf{x}_{1:d/2}} & \frac{d \mathbf{z}_{d/2:d}}{d \mathbf{x}_{d/2:d}} \end{pmatrix} \\
&= \begin{pmatrix} \mathbb{I} & 0 \\ \frac{d \mathbf{z}_{d/2:d}}{d \mathbf{x}_{1:d/2}} & \text{diag}(s_\theta(\mathbf{x}_{1:d/2})) \end{pmatrix}.
\end{aligned}\]
<p>Although the lower-left block $\mathbf{J}_{21}$ contains complicated terms, we do not have to consider them when computing $\det \mathbf{J}$: the determinant of a block lower triangular matrix is the product of the determinants of its diagonal blocks. Hence, $\det \mathbf{J} = \det \mathbf{J}_{11} \cdot \det \mathbf{J}_{22} = \prod_k s_\theta(\mathbf{x}_{1:d/2})_k$, which takes only $O(d)$ time to compute.</p>
<p>In practice, flow layers take a slightly more complicated form than the conceptual architecture detailed above. One easy and necessary modification is to permute the dimensions between layers so that a different subset is left unchanged at each layer; otherwise, the top half of the input representation would never be altered, even after passing through $n$ layers. Another sensible modification is to apply a more expressive transformation. For example, <a href="https://arxiv.org/abs/1605.08803">Real NVP</a> proposes the following schema:</p>
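<p>Below is a minimal NumPy sketch of a single affine coupling layer. It assumes tiny random linear maps as stand-ins for the networks $s_\theta$ and $t_\theta$, and wraps the scale in an exponential to keep it strictly positive (one possible choice, not necessarily the one used by any particular paper). It checks the two properties discussed above: the inverse recovers the input, and $\log |\det \mathbf{J}| = \sum_k \log s_\theta(\mathbf{x}_{1:d/2})_k$, confirmed against a finite-difference Jacobian.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6  # embedding dimensionality (even, so it splits into two halves)
half = d // 2

# Hypothetical stand-ins for the networks s_theta and t_theta.
W_s, W_t = rng.normal(size=(half, half)), rng.normal(size=(half, half))


def s_theta(x_a):
    return np.exp(np.tanh(x_a @ W_s))  # strictly positive scale


def t_theta(x_a):
    return x_a @ W_t


def forward(x):
    x_a, x_b = x[:half], x[half:]
    z_b = x_b * s_theta(x_a) + t_theta(x_a)  # top half x_a is copied unchanged
    log_det = np.sum(np.log(s_theta(x_a)))   # log|det J| = sum of log scales
    return np.concatenate([x_a, z_b]), log_det


def inverse(z):
    z_a, z_b = z[:half], z[half:]
    x_b = (z_b - t_theta(z_a)) / s_theta(z_a)  # elementwise division
    return np.concatenate([z_a, x_b])


x = rng.normal(size=d)
z, log_det = forward(x)
assert np.allclose(inverse(z), x)

# Central-difference Jacobian (column k = dz/dx_k) to confirm the structure and determinant.
eps = 1e-6
J = np.stack([(forward(x + eps * e)[0] - forward(x - eps * e)[0]) / (2 * eps) for e in np.eye(d)], axis=1)
assert np.allclose(J[:half, half:], 0.0, atol=1e-8)  # top-right block is zero
assert np.isclose(np.log(abs(np.linalg.det(J))), log_det, atol=1e-5)
```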
\[\begin{aligned}
\mathbf{z}_{1:d/2} &= \mathbf{x}_{1:d/2} \\
h &= a \times \text{tanh}(s_\theta(\mathbf{x}_{1:d/2})) + b \\
\mathbf{z}_{d/2:d} &= \text{exp}(h) \times \mathbf{x}_{d/2:d} + g_\theta(\mathbf{x}_{1:d/2}).
\end{aligned}\]
<p>To summarize:</p>
<ul>
<li>Flow models are based on the change of variables formula, which expresses the PDF of a transformed random variable in terms of the base PDF and the Jacobian determinant of the transformation.</li>
<li>Since flow models can compute the likelihood of data exactly using a prior, they are trained to maximize the log likelihood of observed data.</li>
<li>Many architectures, such as affine coupling layers, have been proposed to fulfill the invertibility and Jacobian determinant constraints of flow.</li>
</ul>
<p>Now that we have understood how flow works, let’s examine how flow is used in Glow-TTS.</p>
<h3 id="glow-tts">Glow-TTS</h3>
<p>Glow-TTS uses a flow-based decoder that transforms mel-spectrograms into a latent representation. As can be seen below in the architecture diagram, Glow-TTS accepts ground-truth mel-spectrograms (top of figure) and ground-truth text tokens (bottom of figure, shown as “a b c”) during training. Then, it runs the monotonic alignment search algorithm, which we will explore in the next section, to find an alignment between text and speech. The main takeaway is that the flow-based decoder transforms mel-spectrograms $\mathbf{y}$ to some latent vector $\mathbf{z}$, i.e., $f(\mathbf{y}) = \mathbf{z}$.</p>
<p><img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2021-08-10_at_2.50.30_PM.png" /></p>
<p>At a glance, it might not be immediately clear why we would want to use a flow model for the decoder instead of, for instance, a CNN or a transformer. However, the inference procedure makes clear why we need a flow-based model as the decoder. To synthesize a mel-spectrogram during inference, we estimate latent representations from the user’s input text, then pass them to the decoder. Since the decoder is invertible, we can run it in reverse to obtain a predicted mel-spectrogram, i.e., $f^{-1}(\hat{\mathbf{z}}) = \hat{\mathbf{y}}$, where $\hat{\cdot}$ denotes a prediction (as opposed to a ground truth). In Glow-TTS, invertibility offers an intuitive, elegant way of switching from training to inference.</p>
<p>The part that remains unexplained is how the model learns the latent representations and the relationship between text and acoustic features. This is explained by monotonic alignment search, which is the main topic of the next section.</p>
<h2 id="monotonic-alignment-search">Monotonic Alignment Search</h2>
<p>Proposed by Kim et al., Monotonic Alignment Search (MAS) is an algorithm for efficiently identifying the most likely monotonic alignment between speech and text.</p>
<p><img src="https://distill.pub/2017/ctc/thumbnail.jpg" /></p>
<p>Text-to-speech alignment refers to the correspondence between text and spoken audio. Consider a simple input, “hello!”, accompanied by a human recording of that sentence. We could imagine that the first 0.5 seconds of the audio corresponds to the first letter “h,” followed by 0.7 seconds of “e,” and so on. The process of attributing a specific text token to some time interval within the audio can be described as alignment search.</p>
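<p>One way to picture an alignment is as a binary matrix with one row per text token and one column per audio frame. A small sketch, assuming hypothetical per-token durations measured in mel frames rather than seconds, expands durations into such a matrix:</p>

```python
import numpy as np

tokens = ["h", "e", "l", "l", "o"]
durations = [5, 7, 3, 4, 6]  # hypothetical: number of mel frames per token


def durations_to_alignment(durations):
    """Binary matrix A with A[i, j] = 1 iff frame j is attributed to token i."""
    num_frames = sum(durations)
    alignment = np.zeros((len(durations), num_frames), dtype=int)
    start = 0
    for i, d in enumerate(durations):
        alignment[i, start:start + d] = 1  # token i covers a contiguous run of frames
        start += d
    return alignment


A = durations_to_alignment(durations)
print(A.shape)  # (5, 25)
print(A.argmax(axis=0))  # token index attributed to each frame
```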
<p>Finding an accurate alignment between speech and text is a critical task in TTS. If the alignment discovered by the model is inaccurate, the model may skip words or repeat certain syllables, both of which are failure modes we want to avoid. One of the most salient features of MAS is that it prevents such failures by building very specific yet sensible inductive biases directly into the alignment search algorithm.</p>
<h3 id="inductive-biases">Inductive Biases</h3>
<p>Let’s begin by enumerating some common-sense intuitions we have about TTS alignments.</p>
<ul>
<li>The model should “read” from left to right in a linear fashion.</li>
<li>The model should always begin with the first letter and end with the last letter.</li>
<li>The model should not skip any text.</li>
<li>The model should not repeat any text.</li>
</ul>
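<p>These intuitions can be stated precisely on a binary alignment matrix (tokens × frames). The hypothetical checker below assumes every frame is assigned to exactly one token, and then requires that the token index starts at the first token, ends at the last, and only ever stays put or advances by one from frame to frame.</p>

```python
import numpy as np


def satisfies_inductive_biases(alignment: np.ndarray) -> bool:
    """Check a binary (num_tokens x num_frames) alignment against the constraints above."""
    if not (alignment.sum(axis=0) == 1).all():
        return False  # every frame must be assigned to exactly one token
    token_per_frame = alignment.argmax(axis=0)
    steps = np.diff(token_per_frame)
    return bool(
        token_per_frame[0] == 0                               # begins with the first token
        and token_per_frame[-1] == alignment.shape[0] - 1     # ends with the last token
        and np.all((steps == 0) | (steps == 1))               # left to right, no skips or repeats
    )


good = np.array([[1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
skips = np.array([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1]])
backtracks = np.array([[1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
print(
    satisfies_inductive_biases(good),
    satisfies_inductive_biases(skips),
    satisfies_inductive_biases(backtracks),
)  # True False False
```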
<p>Many previous alignment search methods do not necessarily enforce these constraints. For instance, Tacotron 2 uses sequence-to-sequence RNN attention to autoregressively build the alignment between speech and text. However, autoregressive alignment search often fails when long input texts are fed into the model, since errors can accumulate along the sequence and yield a highly inaccurate alignment by the end. On the other hand, MAS is not only non-autoregressive, but also designed specifically so that the discovered alignment never violates the inductive biases outlined above. This makes the model much more robust, even when the input sequence is long.</p>
<h3 id="dynamic-programming">Dynamic Programming</h3>
<p>At the heart of MAS is dynamic programming (DP), a common technique for optimizing runtime on problems that can be decomposed into overlapping sub-problems that share the same structure as their parent. DP offers a reasonably efficient way of solving many problems, often in $O(n^d)$ time, where $n$ is the size of the input and $d$ denotes the dimensionality of the DP table. While this section will not attempt to explain DP in full, we will consider a toy problem to motivate DP specifically in the context of MAS.</p>
<p>Consider a classic dynamic programming problem, where the goal is to find a monotonic path that maximizes the sum of scores in a given score matrix. Here, a “monotonic” path starts at the top-left cell, ends at the bottom-right cell, and visits exactly one cell per column: from each cell, it either moves diagonally down to the next row or moves right within the same row. While there might be many ways to approach this problem, here is one possible solution.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">copy</span>
<span class="k">def</span> <span class="nf">find_maximum_sum_path</span><span class="p">(</span><span class="n">scores</span><span class="p">):</span>
<span class="c1"># preliminary variables
</span> <span class="n">num_rows</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
<span class="n">num_cols</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">scores</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="c1"># copy to avoid overriding `scores`
</span> <span class="n">scores2</span> <span class="o">=</span> <span class="n">copy</span><span class="p">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span>
<span class="c1"># base case for first row
</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">):</span>
<span class="n">scores2</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="n">scores2</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span>
<span class="c1"># dynamic programming
</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">):</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">):</span>
<span class="n">scores2</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">+=</span> <span class="nb">max</span><span class="p">(</span><span class="n">scores2</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">][</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">],</span> <span class="n">scores2</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">])</span>
<span class="c1"># backtracking
</span> <span class="c1"># create `path` to return
</span> <span class="n">i</span> <span class="o">=</span> <span class="n">num_rows</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">path</span> <span class="o">=</span> <span class="p">[[</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_cols</span><span class="p">)]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">reversed</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_cols</span><span class="p">)):</span>
<span class="n">path</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">!=</span> <span class="mi">0</span> <span class="ow">and</span> <span class="p">(</span><span class="n">i</span> <span class="o">==</span> <span class="n">j</span> <span class="ow">or</span> <span class="n">scores2</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]</span> <span class="o"><</span> <span class="n">scores2</span><span class="p">[</span><span class="n">i</span> <span class="o">-</span> <span class="mi">1</span><span class="p">][</span><span class="n">j</span> <span class="o">-</span> <span class="mi">1</span><span class="p">]):</span>
<span class="n">i</span> <span class="o">-=</span> <span class="mi">1</span>
<span class="k">return</span> <span class="n">path</span>
</code></pre></div></div>
<p>Given the following <code class="language-plaintext highlighter-rouge">scores</code>, the function returns the following result:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">>>></span> <span class="n">grid</span> <span class="o">=</span> <span class="p">[</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span>
<span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
<span class="p">]</span>
<span class="o">>>></span> <span class="n">find_maximum_sum_path</span><span class="p">(</span><span class="n">grid</span><span class="p">)</span>
<span class="p">[</span>
<span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
<span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span>
<span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span>
<span class="p">]</span>
</code></pre></div></div>
<p>It is not difficult to perform a manual sanity check to confirm that the returned result is indeed the path that maximizes the sum of scores (1 + 3 + 2 + 0 = 6) while adhering to the monotonicity constraint.</p>
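<p>For small matrices, the manual check can be automated by brute force. The sketch below is a hypothetical helper, independent of <code class="language-plaintext highlighter-rouge">find_maximum_sum_path</code>: it enumerates every monotonic path by choosing the column at which each row begins, and confirms that the best achievable sum on the example grid is 6.</p>

```python
from itertools import combinations


def brute_force_best_sum(scores):
    """Enumerate every monotonic path and return the maximum score sum."""
    num_rows, num_cols = len(scores), len(scores[0])
    best = float("-inf")
    # Row 0 starts at column 0; rows 1..num_rows-1 start at strictly increasing columns,
    # so every row covers a non-empty, contiguous run of columns.
    for starts in combinations(range(1, num_cols), num_rows - 1):
        bounds = (0,) + starts + (num_cols,)
        total = sum(
            scores[i][j]
            for i in range(num_rows)
            for j in range(bounds[i], bounds[i + 1])
        )
        best = max(best, total)
    return best


grid = [
    [1, 3, 1, 1],
    [1, 2, 2, 2],
    [4, 2, 1, 0],
]
print(brute_force_best_sum(grid))  # 6
```

<p>Enumeration grows combinatorially with the matrix size, which is exactly why the DP formulation matters for real alignments.</p>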
<h3 id="likelihood-scores">Likelihood Scores</h3>
<p>Let’s take a step back and revisit the model architecture diagram presented above. On the left side of the diagram, we see an illustration of monotonic alignment search in action. Notice that this is exactly the problem we solved above: given some matrix of scores, find a monotonic path that maximizes the sum. Now, only a few missing pieces remain:</p>
<ul>
<li>What is the matrix of scores?</li>
<li>How does this relate to the flow-based decoder?</li>
</ul>
<p>It turns out that the two questions are closely related, and answering one will shed light on the other.</p>
<p>Recall that Glow-TTS deals with two input modalities during training: a string of text and its corresponding mel-spectrogram. The mel-spectrogram is decoded through the flow-based decoder. Similarly, the text is fed to a text encoder network, which outputs $\mathbf{\mu}$ and $\mathbf{\sigma}$ for each token of text. In other words, given <code class="language-plaintext highlighter-rouge">["h", "e", "l", "l", "o"]</code>, we would have a total of five mean and standard deviation vectors corresponding to each letter.<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup> We can denote them as $\mathbf{\mu_1}, \mathbf{\mu_2}, \dots, \mathbf{\mu_5}$, and $\mathbf{\sigma_1}, \mathbf{\sigma_2}, \dots, \mathbf{\sigma_5}$. Let’s also assume in this example that the corresponding mel-spectrogram spans a total of 100 frames. The output of the flow decoder would also be 100 vectors, denoted as $\mathbf{z_1}, \mathbf{z_2}, \dots, \mathbf{z_{100}}$.</p>
<p>Using these quantities, we can then construct a log-likelihood score matrix $P \in \mathbb{R}^{5 \times 100}$. Its entries are computed via $P_{ij} = \log \phi(\mathbf{z_j}; \mathbf{\mu_i}, \mathbf{\sigma_i})$, where $\phi$ denotes the normal probability density function. Since $\mathbf{\sigma_i}$ is a vector rather than a full covariance matrix, we assume a Gaussian with diagonal covariance, i.e., the dimensions are treated as independent. The intuition is that $P_{ij}$ indicates how likely it is that the $i$-th character aligns with the $j$-th mel-spectrogram frame: if the character and the frame match, the score will be high, and vice versa. Log-likelihoods are used so that summing scores along a path corresponds to multiplying probabilities.</p>
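<p>To make the construction of $P$ concrete, here is a minimal NumPy sketch. It assumes each token’s mean and standard deviation are stored as rows of arrays of shape <code class="language-plaintext highlighter-rouge">(L_text, D)</code> and the decoded frames as rows of shape <code class="language-plaintext highlighter-rouge">(L_mel, D)</code>. The function name <code class="language-plaintext highlighter-rouge">log_likelihood_matrix</code> and the random inputs are hypothetical stand-ins for the encoder and decoder outputs, not the official Glow-TTS code.</p>

```python
import numpy as np

def log_likelihood_matrix(mu, sigma, z):
    """Compute P[i, j] = log N(z_j; mu_i, diag(sigma_i^2)).

    mu, sigma: arrays of shape (L_text, D), one mean/std vector per token.
    z: array of shape (L_mel, D), one latent vector per mel frame.
    Returns an array of shape (L_text, L_mel).
    """
    # Broadcast to (L_text, L_mel, D): every token against every frame.
    diff = z[None, :, :] - mu[:, None, :]
    var = sigma[:, None, :] ** 2
    # Diagonal Gaussian: the log-density is a sum of per-dimension terms.
    log_prob = -0.5 * (np.log(2 * np.pi * var) + diff ** 2 / var)
    return log_prob.sum(axis=-1)

rng = np.random.default_rng(0)
D = 4
mu = rng.normal(size=(5, D))                    # one mean per character of "hello"
sigma = np.exp(0.1 * rng.normal(size=(5, D)))   # strictly positive std devs
z = rng.normal(size=(100, D))                   # 100 decoded mel frames
P = log_likelihood_matrix(mu, sigma, z)         # shape (5, 100)
```

<p>Broadcasting computes all token-frame pairs at once. This is fine for a toy example, but it materializes an <code class="language-plaintext highlighter-rouge">(L_text, L_mel, D)</code> intermediate, which can be large for long utterances.</p>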
<p>Given this context, we can now apply the solution to the monotonic path sum problem motivated in the previous section. Instead of some arbitrary <code class="language-plaintext highlighter-rouge">scores</code> matrix, we create the probability score matrix $P$ and use DP to discover the most likely monotonic alignment between speech and text. The alignment will satisfy the inductive biases we identified earlier due to the inherent design of MAS.</p>
<p>It is worth noting that MAS is a generic alignment search algorithm that is independent of the flow-based model design. In particular, MAS was used without the flow decoder in <a href="https://arxiv.org/abs/2105.06337">Grad-TTS</a>. Popov et al. proposed using mel-spectrogram frames directly to compute the probability score given the mean and variance predicted from text. In other words, mel-spectrogram frames $\mathbf{y}$ were used instead of $\mathbf{z}$. Grad-TTS is notable for its use of score-based generative models, which fall under the larger category of diffusion-based probabilistic models.</p>
<h2 id="glow-tts-pipeline">Glow-TTS Pipeline</h2>
<p>We can finally put flow and MAS together to summarize the overall pipeline of Glow-TTS.</p>
<h3 id="training">Training</h3>
<p>Given a pair of text and mel-spectrogram $(T, \mathbf{y})$, we feed $T$ into the text encoder $f_\text{text}$ and $\mathbf{y}$ into the flow-based decoder $f_\text{mel}$. This gives $f_\text{mel}(\mathbf{y}) \in \mathbb{R}^{D \times L_\text{mel}}$ and $f_\text{text}(T) = (\mu, \sigma)$, where $\mu, \sigma \in \mathbb{R}^{D \times L_\text{text}}$ and $D$ denotes the size of the embedding. We can then use MAS to obtain the most likely monotonic alignment $A^\star \in \{0, 1\}^{L_\text{text} \times L_\text{mel}}$. Because Glow-TTS is a flow-based model, it allows direct computation of likelihood, so the model is simply trained to maximize the log-likelihood given by the entries of the log-likelihood score matrix $P$ selected by the alignment. $A^\star$ can intuitively be understood as a binary mask used to index $P$. Schematically, the final log-likelihood could be written as $l = \sum_{i = 1}^{L_\text{text}} \sum_{j = 1}^{L_\text{mel}}(P \odot A^\star)_{ij}$, where $\odot$ denotes the Hadamard (element-wise) product. Because the decoder is a flow, the exact objective also includes the log-determinant of the decoder’s Jacobian, which this schematic omits. Optimization in modern machine learning is typically framed as a minimization problem, so we minimize the negative log-likelihood.</p>
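<p>Once $A^\star$ is known, the masked objective reduces to a single line. The sketch below uses a hand-made monotonic alignment and made-up probabilities. It omits the flow’s log-determinant term, just as the schematic formula does. The name <code class="language-plaintext highlighter-rouge">masked_nll</code> is hypothetical.</p>

```python
import numpy as np

def masked_nll(P, A):
    """Negative log-likelihood -sum(P * A) for a binary alignment mask A.

    P: (L_text, L_mel) log-likelihood score matrix.
    A: (L_text, L_mel) 0/1 matrix with exactly one 1 per column, so every
       mel frame is assigned to exactly one token.
    """
    return -np.sum(P * A)

# Toy example: 3 tokens, 5 frames, durations 2, 1, 2 (a monotonic path).
P = np.log(np.array([
    [0.9, 0.8, 0.1, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.2, 0.1],
    [0.1, 0.1, 0.2, 0.6, 0.9],
]))
A = np.array([
    [1, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 1],
])
loss = masked_nll(P, A)   # equals -log(0.9 * 0.8 * 0.7 * 0.6 * 0.9)
```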
<p>Although not discussed in the sections above, Glow-TTS also requires training a small sub-model, called a duration predictor, for inference. Because we do not have access to the ground-truth mel-spectrogram during inference, we need a model that can predict the alignment purely from text. This task is carried out by the duration predictor, which accepts $T$ as input. It is trained to minimize the L2 distance between its predicted durations, from which the alignment $\hat{A}$ is constructed, and the durations implied by $A^\star$. These target durations are $d_i = \sum_j A^\star_{ij}$, the number of mel frames assigned to the $i$-th token.</p>
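<p>Target durations are just row sums of the alignment. The sketch below shows that step and an L2 duration loss. We assume the predictor regresses log-durations, a common choice that keeps short and long durations on a comparable scale. The helper names and the <code class="language-plaintext highlighter-rouge">eps</code> guard are hypothetical.</p>

```python
import numpy as np

def durations_from_alignment(A):
    """Number of mel frames assigned to each token: d_i = sum_j A[i, j]."""
    return A.sum(axis=1)

def duration_loss(log_d_pred, A, eps=1e-8):
    """Mean squared error between predicted and target log-durations."""
    # eps guards against log(0) for tokens that received no frames.
    log_d_target = np.log(durations_from_alignment(A) + eps)
    return np.mean((log_d_pred - log_d_target) ** 2)

A = np.array([
    [1, 1, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 1, 1],
])
d = durations_from_alignment(A)          # array([2, 1, 2])
perfect = duration_loss(np.log(d), A)    # essentially zero
```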
<h3 id="inference">Inference</h3>
<p>At inference time, the model has to output a predicted mel-spectrogram $\hat{\mathbf{y}}$ conditioned on the input text $T$. First, we use the learned text encoder to obtain the mean and standard deviation, i.e., $f_\text{text}(T) = (\mu, \sigma)$. Then, we use the duration predictor to obtain a predicted alignment $\hat{A}$. We can then sample from $\mathcal{N}(\mu, \sigma^2)$ according to $\hat{A}$. Continuing the earlier example of <code class="language-plaintext highlighter-rouge">T = ["h", "e", "l", "l", "o"]</code>, let’s say that the predicted durations are <code class="language-plaintext highlighter-rouge">d_hat = [1, 3, 2, 1, 1]</code>. This means that we have to sample from $\mathcal{N}(\mu_\text{h}, \sigma_\text{h}^2)$ once, $\mathcal{N}(\mu_\text{e}, \sigma_\text{e}^2)$ three times, and so on. By concatenating the samples, we obtain $\hat{\mathbf{z}} \in \mathbb{R}^{D \times \hat{L}_\text{mel}}$, where $\hat{L}_\text{mel}$ denotes the number of predicted mel-spectrogram frames, which is simply <code class="language-plaintext highlighter-rouge">sum(d_hat)</code>. Once we have $\hat{\mathbf{z}}$, we finally use the flow decoder to invert it into the mel-spectrogram space, i.e., $f_\text{mel}^{-1}(\hat{\mathbf{z}}) = \hat{\mathbf{y}}$.</p>
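<p>The expand-and-sample step can be sketched with <code class="language-plaintext highlighter-rouge">np.repeat</code>. This is a minimal illustration, assuming per-token statistics are stored as rows of shape <code class="language-plaintext highlighter-rouge">(L_text, D)</code>, the transpose of the $D \times L$ layout used in the text. It stops before the inverse flow $f_\text{mel}^{-1}$. The <code class="language-plaintext highlighter-rouge">temperature</code> argument scales the noise, the sampling knob discussed in the next paragraph. The function <code class="language-plaintext highlighter-rouge">sample_latents</code> is hypothetical.</p>

```python
import numpy as np

def sample_latents(mu, sigma, durations, temperature=1.0, rng=None):
    """Expand per-token Gaussians by predicted durations and sample z_hat.

    mu, sigma: (L_text, D) per-token mean and standard deviation.
    durations: (L_text,) nonnegative integers whose sum is L_mel_hat.
    Returns z_hat with shape (sum(durations), D).
    """
    rng = np.random.default_rng() if rng is None else rng
    # Repeat row i of mu and sigma durations[i] times along the time axis.
    mu_exp = np.repeat(mu, durations, axis=0)
    sigma_exp = np.repeat(sigma, durations, axis=0)
    eps = rng.standard_normal(mu_exp.shape)
    # Reparametrized sample; the temperature scales the injected noise.
    return mu_exp + temperature * eps * sigma_exp

mu = np.arange(5, dtype=float)[:, None] * np.ones((1, 2))  # token i has mean i
sigma = np.full((5, 2), 0.5)
durations = np.array([1, 3, 2, 1, 1])    # "h", "e", "l", "l", "o"
z_hat = sample_latents(mu, sigma, durations, rng=np.random.default_rng(0))
```

<p>With <code class="language-plaintext highlighter-rouge">temperature=0.0</code>, the sample collapses to the repeated means. This makes it easy to check that the frame-to-token expansion is correct.</p>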
<p>Sample diversity is an important concern in neural TTS. Humans can read a single sentence in many different ways by varying tone, pitch, and timbre, and we want a TTS model to be able to produce similarly diverse samples. One way to achieve this in Glow-TTS is by varying a temperature parameter $\tau$ during sampling. In practice, sampling is performed through the reparametrization trick:</p>
\[\epsilon \sim \mathcal{N}(0, I) \\
\mathbf{z} = \mu + \epsilon \cdot \tau \cdot \sigma.\]
<p>Through listening tests and pitch contours, Kim et al. show that varying the temperature produces diversity among samples generated by Glow-TTS.</p>
<h3 id="results">Results</h3>
<p>A marked advantage of Glow-TTS is that it is a parallel TTS model, in contrast to autoregressive baselines such as Tacotron 2. Autoregressive models require an iterative loop that conditions the output of the current timestep on that of the previous timestep. Parallel models, by contrast, produce the entire output in a single forward pass. In other words, the number of sequential steps a parallel model takes does not depend on the sequence length, whereas that of an autoregressive model grows linearly with the length of the output. This is clear in the comparison figure taken from the Glow-TTS paper.</p>
<p><img src="https://media.arxiv-vanity.com/render-output/5100370/x6.png" /></p>
<p>Another pitfall of autoregressive models is that errors can accumulate throughout the iterative loop. If the model misidentifies an alignment between speech and text early in the sequence, later alignments are also likely to be incorrect. Parallel models avoid this failure mode because there is no iterative loop to begin with. Moreover, alignments found by Glow-TTS are made even more robust by the design of MAS, which only considers alignments that satisfy the monotonicity inductive bias. In the figure below, also taken directly from the Glow-TTS paper, Kim et al. show that Glow-TTS maintains a consistent character error rate, while that of Tacotron 2 increases as the input sequence grows longer.</p>
<p><img src="https://media.arxiv-vanity.com/render-output/5100370/x9.png" /></p>
<p>Glow-TTS achieves competitive results on mean opinion score (MOS) listening tests. MOS tests are typically performed by asking a random sample of listeners to rate audio samples on a scale from 1 to 5, where higher is better.</p>
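<p>As a rough illustration of how MOS is aggregated, the sketch below averages a handful of hypothetical ratings. It attaches a 95% confidence interval using a normal approximation, a common way of reporting MOS. The ratings are invented for illustration and do not come from the paper.</p>

```python
import numpy as np

def mos_with_ci(ratings, z=1.96):
    """Mean opinion score and half-width of a normal-approximation 95% CI."""
    r = np.asarray(ratings, dtype=float)
    mean = r.mean()
    # Standard error of the mean, using the sample standard deviation.
    sem = r.std(ddof=1) / np.sqrt(len(r))
    return mean, z * sem

# Hypothetical ratings on a 1-5 scale, for illustration only.
ratings = [4, 5, 4, 3, 5, 4, 4, 5]
mean, half_width = mos_with_ci(ratings)   # mean = 4.25, half_width = 0.49
```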
<p>In the results table shown below, GT (ground-truth) is rated most highly at 4.54. WaveGlow is a neural vocoder that transforms mel-spectrograms into waveforms. GT (Mel + WaveGlow) received 4.19, below the GT waveform score, because a neural vocoder inevitably introduces some quality degradation and artifacts. Even the best neural TTS acoustic feature generator is unlikely to produce a mel-spectrogram that sounds more natural than one extracted from a human recording. For that reason, 4.19 can be considered a practical upper bound for any TTS model paired with WaveGlow. Glow-TTS comes fairly close to 4.19, scoring approximately 4 across various temperature parameters. While the difference of roughly 0.19 suggests room for improvement, it is worth mentioning that Glow-TTS outperforms Tacotron 2, which had long been considered the state-of-the-art TTS model.</p>
<p><img src="https://d3i71xaburhd42.cloudfront.net/4a028532ec2bd4930c5cb228aabae64f28def55f/6-Table1-1.png" /></p>
<h3 id="future-direction">Future Direction</h3>
<p>An emerging trend in neural TTS literature is end-to-end TTS modeling. The traditional pipeline has two stages: an acoustic feature generator followed by a neural vocoder. End-to-end models instead produce raw waveforms directly from text, without passing through an intermediate mel-spectral representation. One prime example is <a href="https://arxiv.org/abs/2106.06103">VITS</a>, an end-to-end speech model developed by the authors of Glow-TTS and published at ICML 2021. VITS is a combination of Glow-TTS and <a href="https://arxiv.org/abs/2010.05646">HiFi-GAN</a>, a neural vocoder. VITS uses largely the same MAS algorithm as Glow-TTS and adopts a variational autoencoder training scheme to combine the feature generator and the neural vocoder.</p>
<p>A benefit of end-to-end modeling is that the model is freed from the mel-spectral information bottleneck. The mel-spectrogram is a specific representation of information defined and crafted according to human knowledge. However, the spirit of deep learning is that no manual feature engineering is necessary, provided sufficient data and modeling capacity. End-to-end training lets the model choose its own intermediate representation, whichever best accomplishes the task of synthesizing natural-sounding audio. Indeed, VITS outperforms Tacotron 2 and Glow-TTS by considerable margins and almost matches ground-truth MOS ratings. This is certainly an exciting development, and we can expect more lines of work in this direction.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Glow-TTS is a flow-based neural TTS model that demonstrated how the invertibility of flow can be leveraged to produce mel-spectrograms from text-derived latent representations. By projecting mel-spectrograms and text into a common latent space and using MAS and maximum likelihood-based training, Glow-TTS is able to learn robust, hard monotonic alignments between speech and text. Similar to Tacotron 2, Glow-TTS is now considered a competitive baseline and is referenced in recent literature.</p>
<p>Neural TTS has seen exciting developments over the past few years, including general text-to-speech, voice cloning, singing voice synthesis, and prosody transfer. Moreover, given the rapid pace of development in other fields, such as natural language processing, automatic speech recognition, and multimodal modeling, we could see more interesting models that combine different approaches and modalities to perform a wide array of complex tasks. If anything is clear, it is that we are living in an exciting era of machine learning. The next few years will likely continue to bring breakthroughs and innovations that awe and surprise us, just as people a few decades ago would have marveled at the simplest words:</p>
<p>“Turn right at 130 Prospect Street.”</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>While there are variations of normalizing flows, such as continuous flows or neural ODEs, for the sake of simplicity, we only consider discrete normalizing flows. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p>In practice, most TTS models, including Glow-TTS, use phonemes as input instead of characters of text. We illustrate the example using characters for simplicity. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Jake TaeNote: This blog post was completed as part of Yale’s CPSC 482: Current Topics in Applied Machine Learning.Reflections and Expectations2021-12-27T00:00:00+00:002021-12-27T00:00:00+00:00https://jaketae.github.io/blog/2021<p>Last year, I wrote a <a href="https://jaketae.github.io/blog/2021/">blog post reflecting on the year 2020</a>. Re-reading what I had written then was surprisingly insightful, particularly because I could see how life had changed in some ways and remained unchanged in others. I decided to continue the tradition this year in the hopes of presenting my year-later self with the same joy and delight of reading a memoir of similar kind.</p>
<p>2021 was, in some ways, very similar to 2020. Despite the development and proliferation of vaccines, COVID-19 raged on, morphing into a new variant every few months. Masks and social distancing are now deeply embedded into our daily lives. Although booster shots and pill-type medications might change the dynamics of the pandemic, I personally think COVID is here to stay, at least for the foreseeable future.</p>
<p>After being discharged from the army in March of 2021, I spent roughly 6 months working as an intern at <a href="https://neosapience.com">Neosapience</a>, a Korean startup specializing in voice-over services and metaverse characters. This was also when I left <a href="https://www.rerent.co">ReRent</a>, a hospitality startup that I was fortunate enough to have worked for since the summer of 2020. ReRent immensely helped me learn and grow as a software developer. There, I became versed in <code class="language-plaintext highlighter-rouge">git</code> and GitHub, general web development, and Django, which has since become my favorite Python backend framework. It is also where I met valuable teammates, some of whom I later met in person at Yale.</p>
<p>The transition from ReRent to Neosapience was a lot more than just a change of jobs. At Neosapience, I worked on machine learning research–an art of its own entirely different from backend web development. Specifically, I was tasked with the job of developing a singing voice synthesis model that, given lyrics and melodies, could “sing.” I still remember the frustration I felt when I was first trying to reproduce a reference paper I was provided as a baseline. There were parts of the paper that were ambiguous. The fact that it was a GAN-based model certainly did not help. I reached out to the authors in the hopes of gaining clarity, but received no response. Although I extrapolated parts of the model and trained it for a few days, the model only produced barely audible mumbles that could not be farther from the act of singing. I learned that ML was hard.</p>
<p>Thankfully, I was fortunate enough to have had more experienced co-workers as mentors who provided valuable advice. One of them suggested that I design a model of my own instead of blindly trying to reproduce the paper. As a demo of sorts, he showed me that a simple CNN model could sing better than the GAN I was trying to reproduce, with just a few minutes of training. Inspired by his progress, I began designing my own modules to experiment with a host of different architectures: CNNs, RNNs, transformers, and combinations thereof. I also explored various famous CNN architectures, such as InceptionNet and ResNeXt, in search of inspiration and ideas.</p>
<p>Unexpectedly, the biggest success came from a very experimental model that was a direct adaptation of <a href="https://arxiv.org/abs/2105.01601">MLP-Mixer</a>, an architecture composed entirely of multi-layer perceptrons, or <code class="language-plaintext highlighter-rouge">nn.Linear</code> layers in PyTorch. This was a paper I presented during one of our weekly paper-reading meetings. The final model’s outputs still contained audible artifacts. Nonetheless, we saw novelty in the fact that it was the first voice synthesis model composed exclusively of linear layers. This project culminated in my first ever publication, <a href="https://arxiv.org/abs/2106.07886">MLP Singer: Towards Rapid Parallel Korean Singing Voice Synthesis</a>, at the <a href="https://2021.ieeemlsp.org">IEEE Machine Learning for Signal Processing workshop</a>, now available on <a href="https://ieeexplore.ieee.org/document/9596184">IEEE Xplore</a>. By the end of my internship, I felt a lot more comfortable with various ML concepts and their implementations. This was also when I got involved with Hugging Face’s Flax/JAX community week, where my teammates and I developed <a href="https://github.com/jaketae/koclip">KoCLIP</a>. Around the same time, I joined <a href="https://bigscience.huggingface.co">BigScience</a>, a huge project by Hugging Face to reproduce a GPT-3-sized language model.</p>
<p>I came back to Yale with the explicit intent of majoring in Computer Science and Mathematics. While this was not a trivial decision, it was very clear and obvious to me that this was the academic path I wanted to pursue. I took CPSC 223, which is Yale’s signature data structures course taught in… barebones C. <code class="language-plaintext highlighter-rouge">malloc</code> and <code class="language-plaintext highlighter-rouge">free</code> are probably the functions I used the most this year, perhaps with the exception of <code class="language-plaintext highlighter-rouge">print</code>/<code class="language-plaintext highlighter-rouge">printf</code>s I used for lazy debugging. On top of CS classes, I also continued my involvement with ML in a few ways. For one thing, I co-authored my second paper, <a href="https://arxiv.org/abs/2110.02584">EdiTTS: Score-based Editing for Controllable Text-to-Speech</a>, with a co-worker at Neosapience. This was the first project in which I used Amazon Mechanical Turk for MOS measurements. I’m still waiting on the final decision from a conference to which I submitted this paper, but I’m happy about how it came out regardless.</p>
<p>More importantly, I was extremely fortunate to be given the opportunity to work as a software engineering intern at Hugging Face. This was an unbelievable achievement for me that I knew I did not deserve. As a self-taught newcomer and student to the field of ML, I only dreamed about working at Hugging Face when I was first learning about transformers. I still have not produced much output at HF largely due to the fact that my internship was part-time and very low time commitment-wise, but I’m still excited for the month of January, which is when I will be dedicating myself full time to Hugging Face and BigScience. I would also like to express gratitude to the engineer at Hugging Face who referred me to this position, and whom I now consider a mentor, <a href="https://twitter.com/stasbekman">Stas Bekman</a>.</p>
<p>This semester was perhaps the hardest one yet at Yale. All the classes I took required a lot of effort or time commitment. Admittedly, in part to fulfill my distribution requirement, I went out of my way to take HIST 271: European Intellectual History since Nietzsche. There, I learned a ton about philosophy, from the Enlightenment all the way up to postmodernism. I also enrolled in ASTR 110: Planets and Stars, which I frankly took for an easy science credit, only to realize that the weekly problem sets took up more time than I had anticipated. MATH 241: Probability Theory was easy at first, but ramped up quite quickly at the end of the semester, to the point that I was floundering during finals week. Nonetheless, I’m glad that the semester is over, and that I came out of it more learned and knowledgeable than I was five months ago.</p>
<p>2021 was surely a roller coaster ride. It was a fruitful one, but it is also a miracle that it turned out the way it did. With experience, memories, and gratitude at heart, I cannot wait to see what 2022 has in store.</p>Jake TaeScore Matching2021-12-26T00:00:00+00:002021-12-26T00:00:00+00:00https://jaketae.github.io/study/sliced-score-matching<p>Recently, I’ve heard a lot about score-based networks. In this post, I will attempt to provide a high-level overview of what scores are and how the concept of score matching gives rise to a family of likelihood-based generative models. This post is heavily adapted from <a href="https://yang-song.github.io/blog/2019/ssm/">Yang Song’s post on sliced score matching</a>.</p>
<h1 id="probability-model">Probability Model</h1>
<p>Given a parametrized real-valued function $f_\theta(\mathbf{x})$, we can derive a probability model $p_\theta(\mathbf{x})$ by applying a normalization term $Z_\theta$.</p>
\[p_\theta (\mathbf{x}) = \frac{e^{- f_\theta (\mathbf{x})}}{Z_\theta} \\
Z_\theta = \int e^{- f_\theta (\mathbf{x})} \, d \mathbf{x}.\]
<p>A model of this form is known as an energy-based model (EBM), where $f_\theta$ plays the role of the energy function.</p>
<p>We can then write the log-likelihood as follows:</p>
\[\log p_\theta (\mathbf{x}) = - f_\theta (\mathbf{x}) - \log Z_\theta.\]
<p>However, one glaring problem with this formulation is that $Z_\theta$ is often intractable. Score-matching presents an elegant solution to bypass this problem.</p>
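<p>To see why $Z_\theta$ is the bottleneck, note that in one dimension we can still approximate it by brute force on a grid. The sketch below does this for a hypothetical energy $f(x) = x^2 / 2$, whose normalizer is $\sqrt{2\pi}$. The same grid approach needs exponentially many points as the dimension grows, which is what makes $Z_\theta$ intractable for realistic data.</p>

```python
import numpy as np

# Hypothetical energy f(x) = x^2 / 2, so exp(-f) is an unnormalized
# standard Gaussian and the exact normalizer is sqrt(2 * pi).
def energy(x):
    return 0.5 * x ** 2

xs = np.linspace(-10.0, 10.0, 20001)
dx = xs[1] - xs[0]
Z = np.sum(np.exp(-energy(xs))) * dx   # Riemann-sum approximation of Z

def log_prob(x):
    # log p(x) = -f(x) - log Z
    return -energy(x) - np.log(Z)
```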
<h1 id="score-matching">Score-Matching</h1>
<p>To eliminate the intractable term, we consider the score, which is defined as the gradient of the log-likelihood with respect to the random variable $\mathbf{x}$. Note that we are not taking the gradient with respect to the parameter $\theta$, which is typically the object of interest in procedures such as MLE. Since $\log Z_\theta$ does not depend on $\mathbf{x}$, its gradient with respect to $\mathbf{x}$ is zero, and the intractable term disappears:</p>
\[\nabla_\mathbf{x} \log p_\theta (\mathbf{x}) = - \nabla_\mathbf{x} f_\theta (\mathbf{x}).\]
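<p>A quick numerical check of this cancellation: the sketch below evaluates the score of a hypothetical energy $f(x) = x^4/4 - x^2$ by central finite differences, under two arbitrary guesses for $\log Z$. Both guesses give the same score, which also matches $-f'(x)$.</p>

```python
import numpy as np

def energy(x):
    # Hypothetical smooth energy with two wells.
    return 0.25 * x ** 4 - x ** 2

def log_prob_unnormalized(x, log_Z):
    # log p(x) = -f(x) - log Z for an arbitrary guess of log Z.
    return -energy(x) - log_Z

def finite_diff_score(x, log_Z, h=1e-5):
    # Central-difference approximation of d/dx log p(x).
    return (log_prob_unnormalized(x + h, log_Z)
            - log_prob_unnormalized(x - h, log_Z)) / (2 * h)

x = np.linspace(-2.0, 2.0, 9)
s1 = finite_diff_score(x, log_Z=0.0)
s2 = finite_diff_score(x, log_Z=123.4)
analytic = -(x ** 3 - 2 * x)   # -f'(x)
```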
<p>The goal of score-matching, then, is to bring $p_\theta$ close to $p_\text{data}$ by minimizing the Fisher divergence, i.e., the expected squared distance between their scores. For the sake of simplicity, we first consider the 1-D case.</p>
\[\begin{align}
&\frac12 \mathbb{E}_{p_\text{data}} \lVert \nabla_x \log p_\text{data} (x) - \nabla_x \log p_\theta (x) \rVert^2_2 \\
&= \frac12 \int p_\text{data} (x) \left( \nabla_x \log p_\text{data} (x) - \nabla_x \log p_\theta (x) \right)^2 \, dx \\
&= \frac12 \int p_\text{data}(x) (\nabla_x \log p_\text{data}(x))^2 \, dx + \frac12 \int p_\text{data} (x) (\nabla_x \log p_\theta (x))^2 \, dx \\
& - \int p_\text{data}(x) \nabla_x \log p_\text{data}(x) \nabla_x \log p_\theta (x) \, dx .
\end{align}\]
<p>The first equality follows from the integral definition of expectation, and the second from expanding the square. Note that the first term does not depend on $\theta$, so it is a constant that can be ignored during optimization.</p>
<p>Now apply integration by parts to the last term. We assume that $p_\text{data}(x) \nabla_x \log p_\theta(x) \to 0$ as $\lvert x \rvert \to \infty$, so the boundary term vanishes:</p>
\[\begin{align}
& \int p_\text{data}(x) \nabla_x \log p_\text{data}(x) \nabla_x \log p_\theta (x) \, dx \\
&= \int p_\text{data}(x) \frac{\nabla_x p_\text{data}(x)}{p_\text{data} (x)} \nabla_x \log p_\theta (x) \, dx \\
&= \int \nabla_x \log p_\theta (x) \nabla_x p_\text{data} (x) \, dx \\
&= p_\text{data}(x) \nabla_x \log p_\theta(x) \bigg|^\infty_{- \infty} - \int p_\text{data}(x) \nabla^2_x \log p_\theta (x) \, dx \\
&= - \mathbb{E}_{p_\text{data}}[\nabla^2_x \log p_\theta (x)].
\end{align}\]
<p>Putting all terms together,</p>
\[\begin{align}
&\frac12 \mathbb{E}_{p_\text{data}} \lVert \nabla_x \log p_\text{data} (x) - \nabla_x \log p_\theta (x) \rVert^2_2 \\
&= \mathbb{E}_{p_\text{data}}[\nabla^2_x \log p_\theta (x)] + \frac12 \mathbb{E}_{p_\text{data}} [(\nabla_x \log p_\theta (x))^2] + \text{const.} \\
&= \mathbb{E}_{p_\text{data}}[\nabla^2_x \log p_\theta (x) + \frac12 (\nabla_x \log p_\theta (x))^2] + \text{const.}
\end{align}\]
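<p>As a sanity check of the 1-D objective, the sketch below fits a Gaussian model $\mathcal{N}(m, s^2)$ with the mean fixed at the sample mean. For this model, the score $-(x - m)/s^2$ and its derivative $-1/s^2$ are known in closed form. Minimizing the empirical objective over $s$ should recover the sample standard deviation without ever computing $Z_\theta$. The data, grid, and function name are illustrative assumptions.</p>

```python
import numpy as np

def score_matching_objective(x, m, s):
    """Empirical 1-D objective E[d2/dx2 log p + 0.5 * (d/dx log p)^2].

    For the Gaussian model N(m, s^2), the score is -(x - m) / s^2 and
    its derivative with respect to x is the constant -1 / s^2.
    """
    score = -(x - m) / s ** 2
    score_grad = -1.0 / s ** 2
    return np.mean(score_grad + 0.5 * score ** 2)

rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.5, size=10_000)

# Grid search over the model scale, with the mean fixed at the sample mean.
scales = np.linspace(0.5, 3.0, 501)
losses = [score_matching_objective(data, data.mean(), s) for s in scales]
best_scale = scales[int(np.argmin(losses))]
```

<p>Setting the derivative with respect to $1/s^2$ to zero gives $s^2 = \mathbb{E}[(x - m)^2]$. So the grid minimum lands next to <code class="language-plaintext highlighter-rouge">data.std()</code>, the same answer maximum likelihood would give.</p>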
<p>The same derivation extends to the multidimensional case, where the second derivative becomes the trace of the Hessian:</p>
\[\mathbb{E}_{p_\text{data}} \left[\text{tr}(\nabla^2_\mathbf{x} \log p_\theta (\mathbf{x})) + \frac12 \lVert \nabla_\mathbf{x} \log p_\theta (\mathbf{x}) \rVert^2_2 \right] + \text{const.}\]
<h1 id="sliced-score-matching">Sliced Score-Matching</h1>
<p>We are specifically interested in instances where $f_\theta$ is parametrized as a neural network. Recall that</p>
\[\nabla_\mathbf{x} \log p_\theta (\mathbf{x}) = - \nabla_\mathbf{x} f_\theta (\mathbf{x}).\]
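<p>Therefore, since $\nabla^2_\mathbf{x} \log p_\theta (\mathbf{x}) = - \nabla^2_\mathbf{x} f_\theta (\mathbf{x})$ and the squared norm is unaffected by the sign flip, we can rewrite the score-matching objective as</p>
\[\mathbb{E}_{p_\text{data}} \left[- \text{tr}(\nabla^2_\mathbf{x} f_\theta (\mathbf{x})) + \frac12 \lVert \nabla_\mathbf{x} f_\theta (\mathbf{x}) \rVert^2_2 \right] + \text{const}.\]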
<p>The first-order gradient can be obtained with a single backpropagation pass. By contrast, $\text{tr}(\nabla^2_\mathbf{x} f_\theta (\mathbf{x}))$ is computationally costly: computing the diagonal of the Hessian with automatic differentiation requires roughly one additional backward pass per input dimension. To circumvent this problem, the authors propose random projections, which reduce the vector-valued scores to scalars. Quoting Yang Song:</p>
<blockquote>
<p>We propose <strong>sliced score matching</strong> to greatly scale up the computation of score matching. The motivating idea is that one dimensional data distribution is much easier to estimate for score matching. We propose to project the scores onto random directions, such that the vector fields of scores of the data and model distribution become scalar fields. We then compare the scalar fields to determine how far the model distribution is from the data distribution. It is clear to see that the two vector fields are equivalent if and only if their scalar fields corresponding to projections onto all directions are the same.</p>
</blockquote>
<p>The random projection version of the Fisher divergence is</p>
\[\frac{1}{2}\mathbb{E}_{p_\mathbf{v}} \mathbb{E}_{p_\text{data}}[(\mathbf{v}^\intercal \nabla_\mathbf{x} \log p_\text{data}(\mathbf{x}) - \mathbf{v}^\intercal \nabla_\mathbf{x} \log p_\theta(\mathbf{x}) )^2],\]
<p>where $\mathbf{v}$ is a random direction drawn from a distribution $p_\mathbf{v}$, such as a standard Gaussian. Intuitively, the objective forces the two scores to agree along randomly chosen projections. The directions are random and the objective is averaged over them, so driving it to zero requires the projected scores to match in every direction. As the quote above notes, that happens only when the scores themselves match.</p>
<p>The sliced score-matching objective under this revised Fisher divergence, with $\mathbf{v} \sim p_\mathbf{v}$, is</p>
\[\mathbb{E}_{p_\mathbf{v}} \mathbb{E}_{p_\text{data}}\bigg[\mathbf{v}^\intercal \nabla_{\mathbf{x}}^2 \log p_\theta(\mathbf{x})\mathbf{v} + \frac{1}{2} (\mathbf{v}^\intercal\nabla_\mathbf{x} \log p_\theta(\mathbf{x}))^2 \bigg] + \text{const}.\]
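<p>The problem has now been reduced to a computationally tractable form. The term $\mathbf{v}^\intercal \nabla_{\mathbf{x}}^2 \log p_\theta(\mathbf{x})\mathbf{v}$ equals $\mathbf{v}^\intercal \nabla_\mathbf{x} (\mathbf{v}^\intercal \nabla_\mathbf{x} \log p_\theta(\mathbf{x}))$, which costs only one extra backward pass regardless of the dimension of $\mathbf{x}$.</p>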
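<p>The sketch below checks numerically that, with $\mathbf{v} \sim \mathcal{N}(0, I)$, the sliced objective is an unbiased estimate of the full score-matching objective. This holds because $\mathbb{E}[\mathbf{v}^\intercal H \mathbf{v}] = \text{tr}(H)$ and $\mathbb{E}[(\mathbf{v}^\intercal \mathbf{s})^2] = \lVert \mathbf{s} \rVert^2$. To avoid autodiff, it assumes a hypothetical Gaussian model with mean <code class="language-plaintext highlighter-rouge">mu</code> and precision <code class="language-plaintext highlighter-rouge">Lam</code>, whose Hessian is known in closed form. The standard normal samples merely stand in for $p_\text{data}$. With a neural network, the $\mathbf{v}^\intercal H \mathbf{v}$ term would instead come from a Hessian-vector product.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical Gaussian model: log p(x) = -0.5 (x - mu)^T Lam (x - mu) + const,
# so score(x) = -Lam (x - mu) and the Hessian of log p is -Lam everywhere.
mu = np.array([0.5, -1.0, 2.0])
Lam = np.array([[2.0, 0.5, 0.0],
                [0.5, 1.0, 0.2],
                [0.0, 0.2, 1.5]])   # symmetric positive definite precision

def score(x):
    # x has shape (n, d); each row r maps to -Lam @ (r - mu).
    return -(x - mu) @ Lam

# Stand-in for samples from p_data.
data = rng.normal(size=(50_000, 3))
s = score(data)

# Full score matching: E[tr(Hessian) + 0.5 * ||score||^2], constant dropped.
full = -np.trace(Lam) + 0.5 * np.mean(np.sum(s ** 2, axis=1))

# Sliced score matching: one random direction v ~ N(0, I) per sample.
v = rng.normal(size=data.shape)
v_hess_v = -np.sum((v @ Lam) * v, axis=1)   # v^T (Hessian) v
v_score = np.sum(v * s, axis=1)             # v^T score
sliced = np.mean(v_hess_v + 0.5 * v_score ** 2)
```

<p>The two estimates agree up to Monte Carlo noise. The sliced version never forms a trace, which is the property that makes it scale to high-dimensional inputs.</p>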
<p><em>This post was originally written in July, but polished into its current final form in December. If you spot any rough edges or details I missed, please feel free to reach out to me with corrections.</em></p>Jake TaeFlow Models2021-06-21T00:00:00+00:002021-06-21T00:00:00+00:00https://jaketae.github.io/study/flow<p>In this post, we will take a look at Flow models, which I became obsessed with while reading papers like <a href="https://arxiv.org/abs/2005.11129">Glow-TTS</a> and <a href="https://arxiv.org/abs/2106.06103">VITS</a>. This post is heavily based on <a href="https://www.youtube.com/watch?v=JBb5sSC0JoY">this lecture video</a> by Pieter Abbeel, as well as the accompanying problem sets for the course, available <a href="https://github.com/rll/deepul/blob/master/homeworks/solutions/hw2_solutions.ipynb">here</a>.</p>
<h1 id="motivation">Motivation</h1>
<p>We want a model that satisfies the following:</p>
<ul>
<li>Simplifies complex, intractable distributions</li>
<li>Enables easy sampling and generation</li>
</ul>
<p>The two conditions are somewhat related in the sense that once you have a function (or a neural network that approximates such a function) that maps complex distributions to a tractable latent space, sampling can be performed immediately given that the mapping function is invertible. Invertibility is not something that can be easily assumed in deep learning and thus calls for some specific architectural decisions. Nonetheless, I find this formulation highly compelling and intuitive.</p>
<h1 id="change-of-variables">Change of Variables</h1>
<p>To fully understand the mechanics of flow, we need to first revisit the change of variables formula. Let $X$ denote a random variable, and $f_\theta$, some monotonic, invertible function that maps $X$ to a latent space $Z$. In the simplest case, $f_\theta$ might be the CDF of $X$, and $Z$ might be a uniform distribution $U(0, 1)$. More generally, we have</p>
\[z = f_\theta(x)\]
<p>Note that there exists a one-to-one correspondence between the two random variables, which is important to guarantee invertibility.</p>
<p>Let $p(\cdot)$ denote the PDF of some random variable. Naively, one might think that</p>
\[p(x) \, dx = p(z) \, dz\]
<p>However, this fails to take into account the fact that a small interval of width $dx$ may be stretched or compressed when mapped into $z$ space. Hence, we need a correcting factor, which is the absolute value of the derivative of $z$ with respect to $x$.</p>
\[p(x) = p(z) \left\lvert \frac{\partial f_\theta(x)}{\partial x} \right\rvert
\tag{1}\]
<p>More formally, we can derive this by differentiating the CDF, which we will denote as $P(\cdot)$.</p>
\[\begin{align}
P(Z \leq z)
&= P(f_\theta(X) \leq z) \\
&= P(X \leq f_\theta^{-1}(z))
\end{align}
\tag{2}\]
<p>(2) holds if $f_\theta$ is a monotonically increasing function. If it is monotonically decreasing, then</p>
\[P(Z \leq z) = 1 - P(X \leq f_\theta^{-1}(z))\]
<p>Differentiating both sides of the equation with respect to $z$, we get</p>
\[\begin{align}
p(z)
&= \pm \, p(f_\theta^{-1}(z)) \frac{\partial f_\theta^{-1}(z)}{\partial z} \\
&= p(x) \left\lvert \frac{\partial x}{\partial z} \right\rvert \\
\end{align}
\tag{3}\]
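<p>Rearranging (3), and noting that $\left\lvert \partial x / \partial z \right\rvert = 1 / \left\lvert \partial z / \partial x \right\rvert$, yields (1).</p>
<p>In a multi-dimensional context, the absolute value of the derivative is replaced by the absolute value of the determinant of the Jacobian matrix.</p>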
\[p(x) = p(z) \frac{\text{vol}(dz)}{\text{vol}(dx)} = p(z) \left\lvert \text{det} \frac{dz}{dx} \right\rvert\]
<p>We can understand the determinant of a matrix as the factor by which the corresponding linear transformation scales volume. In this sense, it is a multivariate analogue of the slope.</p>
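<p>The multivariate formula can be checked numerically on a linear map. In the sketch below, $x \sim \mathcal{N}(0, I)$ and $z = Wx$ for a hypothetical invertible matrix $W$, so $z \sim \mathcal{N}(0, W W^\intercal)$ and the Jacobian $dz/dx$ is simply $W$. Evaluating $p(x)$ directly and via $p(z) \lvert \det W \rvert$ should give the same number.</p>

```python
import numpy as np

def gaussian_pdf(x, mean, cov):
    """Density of a multivariate normal N(mean, cov) evaluated at x."""
    d = len(mean)
    diff = x - mean
    quad = diff @ np.linalg.solve(cov, diff)
    norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return np.exp(-0.5 * quad) / norm

# x ~ N(0, I) and an invertible linear map z = W x, so z ~ N(0, W W^T).
W = np.array([[2.0, 0.5],
              [0.0, 1.5]])
x = np.array([0.3, -0.7])
z = W @ x

p_x_direct = gaussian_pdf(x, np.zeros(2), np.eye(2))
# Change of variables: p(x) = p(z) * |det(dz/dx)|, and dz/dx = W here.
p_x_via_z = gaussian_pdf(z, np.zeros(2), W @ W.T) * abs(np.linalg.det(W))
```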
<h1 id="training">Training</h1>
<p>Flow is nothing more than a neural network that models $f_\theta$. It takes a random variable living in some complex, intractable space and maps it to a tractable one. In the case of normalizing flows, the target latent distribution is typically a standard normal distribution.</p>
<p>As is the case with any likelihood model, the goal is to fit a model that maximizes the log likelihood of data. Therefore, the objective is</p>
\[\max \sum_i \log p(x_i) \tag{4}\]
<p>We can substitute the likelihood with an expression using the latent transformed variable in (1). Then, (4) is equivalent to</p>
\[\max \sum_i \log p(f_\theta(x_i)) + \log \, \left\lvert \text{det} \frac{d f_\theta(x_i)}{d x} \right\rvert\]
<p>We train the flow model to minimize negative log likelihood, or equivalently, maximize log likelihood.</p>
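<p>The simplest possible flow makes this objective concrete. The sketch below uses a single elementwise affine map $z = (x - b) e^{-\log a}$ onto a standard normal. Here $\log \lvert dz/dx \rvert = -\log a$ is constant. The parameter names <code class="language-plaintext highlighter-rouge">shift</code> and <code class="language-plaintext highlighter-rouge">log_scale</code> are hypothetical, and parametrizing the scale in log space is an assumption that keeps it positive and the map invertible.</p>

```python
import numpy as np

def standard_normal_logpdf(z):
    return -0.5 * (z ** 2 + np.log(2 * np.pi))

def affine_flow_nll(x, shift, log_scale):
    """Average NLL of x under the flow z = (x - shift) * exp(-log_scale), z ~ N(0, 1)."""
    z = (x - shift) * np.exp(-log_scale)
    # log|dz/dx| is the same for every sample: -log_scale.
    log_det = -log_scale
    log_px = standard_normal_logpdf(z) + log_det
    return -np.mean(log_px)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=1_000)
nll = affine_flow_nll(x, shift=3.0, log_scale=np.log(2.0))
```

<p>For this flow, the objective coincides with the Gaussian NLL of $\mathcal{N}(b, a^2)$. It is minimized by the sample mean and standard deviation, so the log-determinant term is exactly what makes the change of variables a proper likelihood.</p>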
<p>A few remarks:</p>
<ul>
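<li>Notice that there is a Jacobian determinant sitting in the log-likelihood term. This means that the flow model should model a function whose Jacobian determinant is easy to compute. That is not the case for an arbitrary neural network, since computing the determinant of a dense $d \times d$ Jacobian costs $O(d^3)$.</li>
<li>In a normalizing flow, the $\log p(f_\theta(x_i))$ term encourages $f_\theta$ to map data points close to the mean of the Gaussian, where the density is highest. The log-determinant term counteracts this by penalizing transformations that shrink volume.</li>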
</ul>
<h1 id="perks-of-flow">Perks of Flow</h1>
<p>Up to this point, you might think that the flow model is a very intricate machinery that comes with many constraints, e.g., invertibility and an easily computable Jacobian. Nonetheless, I think it has some clear advantages in two aspects.</p>
<h2 id="sampling">Sampling</h2>
<p>To sample from a flow model, all we have to do is sample from the latent distribution, such as a standard Gaussian, and then send the sample through the inverse flow.</p>
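<p>The two-step sampling recipe can be sketched as follows. To keep the example self-contained, the flow is a hypothetical toy elementwise affine map $z = (x - \mu) \odot \exp(-\log \sigma)$ rather than a real trained model. Its inverse is known in closed form, so we can check that the samples have the expected statistics.</p>

```python
import torch
from torch import nn


class ElementwiseAffineFlow(nn.Module):
    """Toy flow z = (x - mu) * exp(-log_sigma); hypothetical, for illustration."""

    def __init__(self, dim):
        super().__init__()
        self.mu = nn.Parameter(torch.randn(dim))
        self.log_sigma = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        return (x - self.mu) * torch.exp(-self.log_sigma)

    def inverse(self, z):
        return z * torch.exp(self.log_sigma) + self.mu


@torch.no_grad()
def sample(flow, num_samples, dim):
    z = torch.randn(num_samples, dim)  # 1) draw from the latent distribution
    return flow.inverse(z)             # 2) send it through the inverse flow


torch.manual_seed(0)
flow = ElementwiseAffineFlow(dim=2)
x = sample(flow, num_samples=100_000, dim=2)

# For this toy flow, x ~ N(mu, diag(sigma^2)), so the sample statistics
# should be close to the flow's parameters.
sigma = flow.log_sigma.exp().detach()
print(x.mean(dim=0), flow.mu.detach())
print(x.std(dim=0), sigma)
```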
<h2 id="combinations">Combinations</h2>
<p>One salient characteristic of a flow is that a combination of flows is also a flow. If you have a set of invertible, differentiable functions, a stack of such functions will also be differentiable and invertible.</p>
\[z = f_k \circ f_{k - 1} \circ \cdots \circ f_1(x) \\
x = f_1^{-1} \circ f_2^{-1} \circ \cdots \circ f_k^{-1} (z)\]
<p>The capacity of a single flow layer is limited, but a deep stack gives the model enough expressive power to handle highly complex data distributions.</p>
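<p>Here is a minimal sketch of stacking flows, under the assumption that every layer returns its output together with a per-example log absolute jacobian determinant. The forward pass applies $f_1$ through $f_k$, and the inverse applies the inverses in reverse order. By the chain rule, the jacobian of the composition is a product of jacobians, so the log determinants add up. The <code class="language-plaintext highlighter-rouge">Scale</code>, <code class="language-plaintext highlighter-rouge">Shift</code>, and <code class="language-plaintext highlighter-rouge">SequentialFlow</code> classes are hypothetical.</p>

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical invertible layer z = x * exp(s); log|det| = sum(s)."""

    def __init__(self, dim):
        super().__init__()
        self.s = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        return x * torch.exp(self.s), self.s.sum().expand(x.shape[0])

    def inverse(self, z):
        return z * torch.exp(-self.s)


class Shift(nn.Module):
    """Hypothetical invertible layer z = x + t; log|det| = 0."""

    def __init__(self, dim):
        super().__init__()
        self.t = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        return x + self.t, torch.zeros(x.shape[0])

    def inverse(self, z):
        return z - self.t


class SequentialFlow(nn.Module):
    """Composition f_k o ... o f_1 of invertible layers."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        total_log_det = torch.zeros(x.shape[0])
        for layer in self.layers:  # apply f_1 first and f_k last
            x, log_det = layer(x)
            # det of a product of jacobians = product of dets, so log-dets add.
            total_log_det = total_log_det + log_det
        return x, total_log_det

    def inverse(self, z):
        for layer in reversed(self.layers):  # apply f_k^{-1} first, f_1^{-1} last
            z = layer.inverse(z)
        return z


torch.manual_seed(0)
d = 3
flow = SequentialFlow([Scale(d), Shift(d), Scale(d), Shift(d)])
x = torch.randn(4, d)
z, log_det = flow(x)
print(torch.allclose(flow.inverse(z), x, atol=1e-5))  # the stack is invertible

# Compare the summed log-dets with the log|det| of the full jacobian (one example).
J = torch.autograd.functional.jacobian(lambda v: flow(v)[0], x[:1]).reshape(d, d)
print(torch.allclose(torch.linalg.slogdet(J).logabsdet, log_det[0], atol=1e-5))
```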
<h1 id="model-architecture">Model Architecture</h1>
<p>Flow models must be invertible, which leads to some important considerations when designing their architecture. For instance, a ReLU activation cannot serve as a flow layer, since it maps every negative input to zero and is therefore not invertible. Moreover, the jacobian determinant should be easy to compute.</p>
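<p>A tiny illustration of the ReLU problem follows. Leaky ReLU appears only as a contrast. It is not something this post proposes, but because it is strictly increasing, its inverse is trivial to write down.</p>

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])

# ReLU sends every non-positive input to 0: three different inputs share the
# same output, so no inverse function can tell them apart.
y_relu = torch.relu(x)
print(y_relu)

# Leaky ReLU is strictly increasing, so it has an exact inverse:
# divide the negative outputs by the slope.
slope = 0.01
y_leaky = F.leaky_relu(x, negative_slope=slope)
x_recovered = torch.where(y_leaky >= 0, y_leaky, y_leaky / slope)
print(torch.allclose(x_recovered, x))  # True
```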
<h2 id="inversion">Inversion</h2>
<p>The beautiful part of flow is that there is a simple way to resolve both conundrums: affine coupling layers. Let $d$ denote the dimensionality of the space on which we are applying a flow model. Then, the affine coupling layer can schematically be written as</p>
\[z_{1:d/2} = x_{1:d/2} \\
\begin{align}
z_{d/2:d}
&= x_{d/2:d} \odot s_\theta(x_{1:d/2}) + t_\theta(x_{1:d/2}) \\
&= x_{d/2:d} \odot s_\theta(z_{1:d/2}) + t_\theta(z_{1:d/2})
\end{align}
\tag{5}\]
<p>In plain language, we can consider $f_\theta$ as a special transformation in which the top half of $z$ is just copied from $x$ without modification. The bottom half undergoes an affine transformation, where the weights and biases are computed from the top half of $x$. We can easily check that this transformation is indeed invertible:</p>
\[x_{1:d/2} = z_{1:d/2} \\
x_{d/2:d} = \frac{z_{d/2:d} - t_\theta(z_{1:d/2})}{s_\theta(z_{1:d/2})}
\tag{6}\]
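<p>Affine coupling layers are invertible because the top half of $z$ is equal to that of $x$. During inversion, we can recompute $s_\theta$ and $t_\theta$ from $z_{1:d/2}$, so $s_\theta$ and $t_\theta$ themselves need not be invertible, as long as the entries of $s_\theta$ are nonzero. This demystifies the copying operation in (5), which may have appeared somewhat unintuitive and awkward initially.</p>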
<p>In practice, it appears that flow layers take a slightly more complicated form than the conceptual architecture detailed above. For example, <a href="https://arxiv.org/abs/1605.08803">Real NVP</a> proposes the following schema.</p>
\[z_{1:d/2} = x_{1:d/2} \\
h = a \times \text{tanh}(s_\theta(x_{1:d/2})) + b \\
z_{d/2:d} = \text{exp}(h) \times x_{d/2:d} + g_\theta(x_{1:d/2})\]
<p>where $a$ and $b$ are learned parameters, and $s_\theta$ and $g_\theta$ are arbitrary neural networks, such as multi-layer perceptrons.</p>
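<p>The sketch below implements the schema above under several assumptions. $s_\theta$ and $g_\theta$ are small two-layer MLPs with a hidden size of 32, $a$ is initialized to ones, and $b$ is initialized to zeros. None of these choices is taken from Real NVP. One plausible benefit of the tanh is that it bounds $h$ elementwise to $[b - \lvert a \rvert, b + \lvert a \rvert]$, so the scale $\exp(h)$ is always nonzero and cannot explode. As a result, the layer is always invertible, and its log absolute jacobian determinant is simply $\sum h$.</p>

```python
import torch
from torch import nn


class TanhAffineCoupling(nn.Module):
    """Coupling layer following the schema above; sizes and names are assumptions.

    z1 = x1
    h  = a * tanh(s(x1)) + b
    z2 = exp(h) * x2 + g(x1)
    """

    def __init__(self, dim, hidden=32):
        super().__init__()
        assert dim % 2 == 0, f"Expected an even dimension, but received {dim}"
        half = dim // 2
        # s and g never need to be inverted, so ReLU is fine inside them.
        self.s = nn.Sequential(nn.Linear(half, hidden), nn.ReLU(), nn.Linear(hidden, half))
        self.g = nn.Sequential(nn.Linear(half, hidden), nn.ReLU(), nn.Linear(hidden, half))
        self.a = nn.Parameter(torch.ones(half))
        self.b = nn.Parameter(torch.zeros(half))

    def _h(self, x1):
        return self.a * torch.tanh(self.s(x1)) + self.b

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        h = self._h(x1)
        z2 = torch.exp(h) * x2 + self.g(x1)
        # The jacobian is block lower triangular with diagonal entries 1 and
        # exp(h), so log|det| = sum(h).
        return torch.cat((x1, z2), dim=1), h.sum(dim=1)

    def inverse(self, z):
        z1, z2 = z.chunk(2, dim=1)
        h = self._h(z1)  # z1 == x1, so h can be recomputed exactly
        x2 = (z2 - self.g(z1)) * torch.exp(-h)
        return torch.cat((z1, x2), dim=1)


torch.manual_seed(0)
layer = TanhAffineCoupling(dim=6)
x = torch.randn(8, 6)
z, log_det = layer(x)
print(torch.allclose(layer.inverse(z), x, atol=1e-5))  # invertible

# Check log|det| against autograd for the first example.
J = torch.autograd.functional.jacobian(lambda v: layer(v)[0], x[:1]).reshape(6, 6)
print(torch.allclose(torch.linalg.slogdet(J).logabsdet, log_det[0], atol=1e-5))
```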
<h2 id="jacobian">Jacobian</h2>
<p>Earlier, we noted that the determinant of the jacobian matrix must be easy to compute. This is a non-trivial constraint that does not hold true in many cases.</p>
<p>Fortunately, it turns out that the jacobian determinant is very easy to compute for an affine coupling layer. We can somewhat intuit this by considering the copy-and-paste operation that is applied to the top half of the input. Because of this operation, the upper left block of the jacobian is simply an identity matrix, and the upper right block is zero, since $z_{1:d/2}$ does not depend on $x_{d/2:d}$.</p>
\[\begin{align}
\frac{\partial z}{\partial x}
&= \begin{pmatrix} \frac{\partial z_{1:d/2}}{\partial x_{1:d/2}} & \frac{\partial z_{1:d/2}}{\partial x_{d/2:d}} \\ \frac{\partial z_{d/2:d}}{\partial x_{1:d/2}} & \frac{\partial z_{d/2:d}}{\partial x_{d/2:d}} \end{pmatrix} \\
&= \begin{pmatrix} I & 0 \\ \frac{\partial z_{d/2:d}}{\partial x_{1:d/2}} & \text{diag}(s_\theta(x_{1:d/2})) \end{pmatrix}
\end{align}\]
<p>Although there are still complicated terms in the lower left block of the jacobian, we do not have to consider them to compute its determinant. The determinant of a block lower triangular matrix is the product of the determinants of its diagonal blocks. Since the upper left block is the identity, the determinant collapses to the product of the diagonal entries of the lower right block, $\prod_i s_\theta(x_{1:d/2})_i$. Thus, the affine coupling layer satisfies both the invertibility and the jacobian determinant requirements.</p>
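<p>We can confirm this block structure numerically with autograd. This is a sketch that assumes, as a stand-in for $s_\theta$ and $t_\theta$, a single hypothetical linear layer <code class="language-plaintext highlighter-rouge">fc</code> that maps $x_{1:d/2}$ to both the scale and the shift. The function works on one unbatched vector so that the jacobian is a plain $d \times d$ matrix.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
d = 6
half = d // 2
fc = nn.Linear(half, d)  # hypothetical stand-in producing both s and t from x1


def coupling(x):
    """z1 = x1, z2 = x2 * s(x1) + t(x1) for a single (unbatched) vector x."""
    x1, x2 = x[:half], x[half:]
    s, t = fc(x1).chunk(2)
    return torch.cat((x1, x2 * s + t))


x = torch.randn(d)
J = torch.autograd.functional.jacobian(coupling, x)  # J[i, j] = dz_i / dx_j
s = fc(x[:half])[:half].detach()

print(torch.allclose(J[:half, :half], torch.eye(half)))          # upper left: I
print(torch.allclose(J[:half, half:], torch.zeros(half, half)))  # upper right: 0
print(torch.allclose(J[half:, half:], torch.diag(s)))            # lower right: diag(s)
print(torch.allclose(torch.linalg.det(J), s.prod(), atol=1e-5))  # det = prod(s)
```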
<h1 id="implementation">Implementation</h1>
<p>This is my attempt at a simple implementation of an affine coupling layer. Although I could have combined the <code class="language-plaintext highlighter-rouge">forward()</code> and <code class="language-plaintext highlighter-rouge">inverse()</code> functions to remove duplicate lines of code, for clarity’s sake, I left them separate.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>

<span class="k">class</span> <span class="nc">AffineCouplingLayer</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">half_size</span><span class="p">,</span> <span class="n">remainder</span> <span class="o">=</span> <span class="nb">divmod</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="k">assert</span> <span class="n">remainder</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="p">(</span>
            <span class="sa">f</span><span class="s">"Expected `hidden_size` to be even, but received </span><span class="si">{</span><span class="n">hidden_size</span><span class="si">}</span><span class="s">"</span>
        <span class="p">)</span>
        <span class="c1"># Maps the top half to a vector holding both the scale s and the shift t.</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">half_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">inverse</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">inverse</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">inverse</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x1</span><span class="p">,</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># The top half is copied unchanged.</span>
        <span class="n">z1</span> <span class="o">=</span> <span class="n">x1</span>
        <span class="c1"># Scale and shift depend on the top half only.</span>
        <span class="n">s</span><span class="p">,</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x1</span><span class="p">).</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">z2</span> <span class="o">=</span> <span class="n">x2</span> <span class="o">*</span> <span class="n">s</span> <span class="o">+</span> <span class="n">t</span>
        <span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">z1</span><span class="p">,</span> <span class="n">z2</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># |det J| = |prod(s)|, since the jacobian is block lower triangular.</span>
        <span class="n">det</span> <span class="o">=</span> <span class="n">s</span><span class="p">.</span><span class="n">prod</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">).</span><span class="nb">abs</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="n">det</span>

    <span class="k">def</span> <span class="nf">inverse</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span>
        <span class="n">z1</span><span class="p">,</span> <span class="n">z2</span> <span class="o">=</span> <span class="n">z</span><span class="p">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">x1</span> <span class="o">=</span> <span class="n">z1</span>
        <span class="n">s</span><span class="p">,</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">fc</span><span class="p">(</span><span class="n">z1</span><span class="p">).</span><span class="n">chunk</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Assumes no entry of s is exactly zero.</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="p">(</span><span class="n">z2</span> <span class="o">-</span> <span class="n">t</span><span class="p">)</span> <span class="o">/</span> <span class="n">s</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</code></pre></div></div>
<p>This implementation is a close transcription of (5). <code class="language-plaintext highlighter-rouge">z1</code> denotes $z_{1:d/2}$; <code class="language-plaintext highlighter-rouge">z2</code>, $z_{d/2:d}$, and ditto the <code class="language-plaintext highlighter-rouge">x</code>s. The fully-connected layer <code class="language-plaintext highlighter-rouge">self.fc</code> maps <code class="language-plaintext highlighter-rouge">x1</code> to a vector of size <code class="language-plaintext highlighter-rouge">hidden_size</code>, which is split into the scale <code class="language-plaintext highlighter-rouge">s</code> and the shift <code class="language-plaintext highlighter-rouge">t</code>. Hence, the output <code class="language-plaintext highlighter-rouge">z2</code> is conditioned on <code class="language-plaintext highlighter-rouge">x1</code> only. The <code class="language-plaintext highlighter-rouge">inverse()</code> method is a transcription of (6).</p>
<p>We can perform a quick sanity check on this implementation by performing a forward pass, as well as an inverse pass, and verifying that inverting the output of the forward pass recovers the original input.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">hidden_size</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">half_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="o">//</span> <span class="mi">2</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
<span class="n">l</span> <span class="o">=</span> <span class="n">AffineCouplingLayer</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">)</span>
<span class="n">z</span><span class="p">,</span> <span class="n">det</span> <span class="o">=</span> <span class="n">l</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">z</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.Size([8, 10])
</code></pre></div></div>
<p>We also get the absolute value of the jacobian determinant, one scalar per example. We get 8 values, which equals the batch size of the example input.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">det</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.Size([8])
</code></pre></div></div>
<p>We can check that the affine coupling layer leaves the top half of the input unchanged.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">torch</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">x</span><span class="p">[:,:</span><span class="n">half_size</span><span class="p">],</span> <span class="n">z</span><span class="p">[:,:</span><span class="n">half_size</span><span class="p">])</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
</code></pre></div></div>
<p>Similarly, we can verify that the bottom half of the output has been modified by the layer.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">torch</span><span class="p">.</span><span class="n">equal</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span><span class="n">half_size</span><span class="p">:],</span> <span class="n">z</span><span class="p">[:,</span><span class="n">half_size</span><span class="p">:])</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>False
</code></pre></div></div>
<p>Most importantly, we can see that the layer is indeed invertible; that is, it recovers the original input given the output of the layer <code class="language-plaintext highlighter-rouge">z</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">torch</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">l</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">inverse</span><span class="o">=</span><span class="bp">True</span><span class="p">))</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
</code></pre></div></div>
<p>We use <code class="language-plaintext highlighter-rouge">torch.allclose()</code> instead of <code class="language-plaintext highlighter-rouge">torch.equal()</code> due to floating point errors that can cause subtle changes in values. This is merely a technicality and does not affect the conclusion that affine coupling layers are fully invertible.</p>
<h1 id="conclusion">Conclusion</h1>
<p>In this post, we discussed flow models. I personally find flow-based models extremely interesting, simply because deep neural networks are normally not something that we can invert like a simple mathematical function. After all, the precise reason why we use deep neural networks is that we want to model complex non-linear functions. Flow models seem to go against this intuition in some sense, while providing us with the tools to map highly complex data distributions to tractable latent distributions.</p>
<p>I hope you enjoyed reading this post. Catch you up in the next one!</p>Jake TaeIn this post, we will take a look at Flow models, which I’ve been obsessed with while reading papers like Glow-TTS and VITS. This post is heavily based on this lecture video by Pieter Abbeel, as well as the accompanied problem sets for the course, available here.From ELBO to DDPM2021-05-17T00:00:00+00:002021-05-17T00:00:00+00:00https://jaketae.github.io/study/elbo<p>In this short post, we will take a look at the variational lower bound, also referred to as the evidence lower bound or ELBO for short. While I have referenced ELBO in a <a href="https://jaketae.github.io/study/vae">previous blog post on VAEs</a>, the proofs and formulations presented in that post seem somewhat convoluted in retrospect. One might consider this a gentler, more refined recap of the topic. For the remainder of this post, I will use the terms “variational lower bound” and “ELBO” interchangeably to refer to the same concept. I was heavily inspired by <a href="https://www.youtube.com/watch?v=pStDscJh2Wo">Hugo Larochelle’s excellent lecture</a> on deep belief networks.</p>
<h1 id="concavity">Concavity</h1>
<p>One important property of the logarithm is that it is a concave function. A function $f$ is concave if it satisfies the following property:</p>
\[f\left( \sum \nolimits_i w_i x_i \right) \geq \sum \nolimits_i w_i f(x_i) \tag{1}\]
<p>for any non-negative weights $w_i$ that sum to one. In other words, if the function evaluated at a weighted average of values is always greater than or equal to the weighted average of the function evaluated at those values, the function is concave.</p>
<p>As a short detour, we discussed a similar concept in the context of variational autoencoders and Jensen’s inequality in an <a href="https://jaketae.github.io/study/vae/">earlier post</a>. In that post, I introduced the definition of convexity as follows:</p>
\[\mathbb{E}[f(x)] \geq f(\mathbb{E}[x])\]
<p>While the notations used are slightly different, it is easy to see that this definition is the exact reverse of (1), with the expectation playing the role of the weighted average. A direct consequence is that a concave function is also convex if and only if it is affine.</p>
<p>Given this understanding, we can now revisit the logarithm and quickly verify that it is a concave function: its second derivative, $-1/x^2$, is negative for all $x > 0$.</p>
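<p>As a quick numerical complement to the second-derivative argument, we can check inequality (1) for $f = \log$ on random positive values with random weights that sum to one. This is a small sketch using only the Python standard library. The helper <code class="language-plaintext highlighter-rouge">jensen_gap</code> is a name made up for this example.</p>

```python
import math
import random

random.seed(0)


def jensen_gap(xs, ws):
    """log(sum_i w_i x_i) - sum_i w_i log(x_i); concavity says this is >= 0."""
    lhs = math.log(sum(w * x for w, x in zip(ws, xs)))
    rhs = sum(w * math.log(x) for w, x in zip(ws, xs))
    return lhs - rhs


gaps = []
for _ in range(1000):
    n = random.randint(2, 6)
    xs = [random.uniform(0.01, 10.0) for _ in range(n)]
    raw = [random.random() for _ in range(n)]
    total = sum(raw)
    ws = [r / total for r in raw]  # non-negative weights summing to 1
    gaps.append(jensen_gap(xs, ws))

print(min(gaps) >= -1e-12)  # True: the inequality held in every trial
# Equality holds when all x_i are equal.
print(abs(jensen_gap([3.0, 3.0, 3.0], [0.2, 0.3, 0.5])) < 1e-12)  # True
```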
<h1 id="variational-lower-bound">Variational Lower Bound</h1>
<p>Before diving into a soup of equations, it’s important to remind ourselves of the problem setup. While ELBO is probably most commonly referenced in the context of variational autoencoders, I have recently seen it being mentioned in diffusion models as well. ELBO is a broad concept that can be applied to discuss any model with hidden latent representations, which we will denote as $h$ henceforth.</p>
<p>More concretely, given a model $p(x, h)$, we can write</p>
\[\begin{align}
\log p(x)
&= \log \left( \sum_{h} p(x, h) \right) \tag{2} \\
&= \log \left( \sum_{h} q(h \vert x) \frac{p(x, h)}{q(h \vert x)} \right) \tag{3} \\
& \geq \sum_{h} q(h \vert x) \log \frac{p(x, h)}{q(h \vert x)} \tag{4} \\
&= \sum_{h} q(h \vert x) \log p(x, h) - \sum_{h} q(h \vert x) \log q(h \vert x) \tag{5} \\
&= \mathbb{E}_q [\log p(x, h) - \log q(h \vert x)] \tag{6}
\end{align}\]
<p>(2) follows from the law of total probability, (3) multiplies and divides by $q(h \vert x)$, (4) follows from Jensen’s inequality since the logarithm is concave, (5) is an algebraic manipulation using the properties of logarithms, and (6) is a rewriting of the expression as an expectation under $q(h \vert x)$.</p>
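<p>To see the bound in action, here is a sketch on a made-up toy model. It assumes that $h$ is discrete with four possible values and that $x$ is a single fixed observation, so that the sum in (2) can be computed exactly. We then evaluate (6) for many random choices of $q(h \vert x)$ and check that none of them exceeds $\log p(x)$. Float64 is used to keep rounding error negligible.</p>

```python
import torch

torch.manual_seed(0)
dtype = torch.float64
num_h = 4  # hypothetical: h takes one of 4 discrete values; x is fixed

# Made-up toy model for a single observation x:
# prior p(h) and likelihood p(x | h), hence log p(x, h) for every h.
p_h = torch.softmax(torch.randn(num_h, dtype=dtype), dim=0)
p_x_given_h = 0.1 + 0.5 * torch.rand(num_h, dtype=dtype)
log_joint = torch.log(p_h * p_x_given_h)

# (2): log p(x) = log sum_h p(x, h), computed stably.
log_px = torch.logsumexp(log_joint, dim=0)


def elbo(q):
    """(6): E_q[log p(x, h) - log q(h | x)] for a distribution q over h."""
    return (q * (log_joint - torch.log(q))).sum()


# The bound (4) should hold for every choice of q(h | x).
gaps = []
for _ in range(1000):
    q = torch.softmax(3 * torch.randn(num_h, dtype=dtype), dim=0)
    gaps.append((log_px - elbo(q)).item())

print(min(gaps) >= -1e-12)  # True: ELBO <= log p(x) in every trial
```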
<h2 id="equivalence-condition">Equivalence Condition</h2>
<p>In the formulation above, $q(h \vert x)$ can be understood as an approximation of a true distribution $p(h \vert x)$. Note that when $q(h \vert x) = p(h \vert x)$, we have an exact equality. Since</p>
\[\log p(x, h) = \log p(h \vert x) + \log p(x)\]
<p>We can substitute $p(h \vert x)$ for $q(h \vert x)$ and rewrite (5) as</p>
\[\begin{align}
\log p(x)
&= \sum_h p(h \vert x) (\log p(h \vert x) + \log p(x)) - \sum_h p(h \vert x) \log p(h \vert x) \\
&= \sum_h p(h \vert x) \log p(x)
\end{align}\]
<p>Since $p(x)$ does not depend on $h$, we can pull out the term from the summation, treating it as a constant, leaving us with</p>
\[\log p(x) \sum_h p(h \vert x)\]
<p>Since $p(h \vert x)$ is a probability distribution over $h$, the summation equals 1, leaving us with $\log p(x)$, which is exactly the quantity that ELBO seeks to approximate.</p>
<p>Variational lower bounds are extremely useful when dealing with models whose interactions between $x$ and the hidden representation $h$ are complex, rendering (2) computationally intractable. Therefore, to train such models, we seek to maximize the log likelihood by pushing the lower bound up.</p>
<h2 id="kl-divergence">KL Divergence</h2>
<p>Recall the definition of KL divergence:</p>
\[\begin{align}
D_\text{KL}(q \parallel p)
&= \sum_{x \in X} q(x) \log \left( \frac{q(x)}{p(x)} \right) \\
&= - \sum_{x \in X} q(x) \log \left( \frac{p(x)}{q(x)} \right) \\
\end{align}\]
<p>We can see the resemblance between this definition and the definition of ELBO as written in (4), which was</p>
\[\log p(x) \geq \sum_{h} q(h \vert x) \log \frac{p(x, h)}{q(h \vert x)} \tag{4}\]
<p>The nice conclusion to this story is that</p>
\[\log p(x) - \text{ELBO} = D_\text{KL}(q(h \vert x) \parallel p(h \vert x)) \tag{7}\]
<p>This is a nice interpretation, since KL divergence is always greater than or equal to zero. Hence, we can confirm that</p>
\[\log p(x) \geq \text{ELBO}\]
<h3 id="proof">Proof</h3>
<p>In this section, we sketch a quick proof for (7).</p>
\[\begin{align}
D_\text{KL}(q(h \vert x) \parallel p(h \vert x))
&= \mathbb{E}_q [\log q(h \vert x) - \log p(h \vert x) ] \\
&= \mathbb{E}_q [\log q(h \vert x) - \log p(x, h) + \log p(x) ] \\
&= \mathbb{E}_q [\log q(h \vert x) - \log p(x, h)] + \log p(x) \\
\end{align}\]
<p>Notice that the expectation is the sign-flipped version of the ELBO term we derived above.</p>
\[\mathbb{E}_q [\log p(x, h) - \log q(h \vert x)] \tag{6}\]
<p>Therefore, we have</p>
\[D_\text{KL}(q(h \vert x) \parallel p(h \vert x)) = - \text{ELBO} + \log p(x) \\ \implies \log p(x) - \text{ELBO} = D_\text{KL}(q(h \vert x) \parallel p(h \vert x))\]
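<p>Identity (7) and the equivalence condition can both be checked numerically on a small discrete example. The sketch below again assumes a made-up toy model in which $h$ takes four values and $x$ is fixed. It compares $\log p(x) - \text{ELBO}$ with the KL divergence to the true posterior, and then confirms that the gap vanishes when $q(h \vert x) = p(h \vert x)$.</p>

```python
import torch

torch.manual_seed(1)
dtype = torch.float64
num_h = 4  # hypothetical discrete latent with 4 values; x is fixed

# Made-up toy model: log p(x, h) = log p(h) + log p(x | h) for a single x.
log_joint = torch.log_softmax(torch.randn(num_h, dtype=dtype), dim=0) + torch.log(
    0.1 + 0.5 * torch.rand(num_h, dtype=dtype)
)
log_px = torch.logsumexp(log_joint, dim=0)
log_posterior = log_joint - log_px  # log p(h | x) = log p(x, h) - log p(x)


def elbo(q):
    return (q * (log_joint - torch.log(q))).sum()


def kl(q, log_p):
    """D_KL(q || p) for discrete q (probabilities) and p (log-probabilities)."""
    return (q * (torch.log(q) - log_p)).sum()


q = torch.softmax(torch.randn(num_h, dtype=dtype), dim=0)
# (7): log p(x) - ELBO = D_KL(q(h | x) || p(h | x))
print(torch.isclose(log_px - elbo(q), kl(q, log_posterior)).item())  # True

# Equivalence condition: with q = p(h | x), the bound is tight.
posterior = log_posterior.exp()
print(torch.isclose(elbo(posterior), log_px).item())  # True
```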
<h1 id="denoising-diffusion-probabilistic-models">Denoising Diffusion Probabilistic Models</h1>
<p>Since we have already seen how ELBO comes up in VAEs, it might be more helpful to take a look at another more recent example I came across while reading <a href="https://arxiv.org/abs/2006.11239">Denoising Diffusion Probabilistic Models</a>, or DDPM for short. The intent of this section is not to go over what DDPMs are, but rather to give a sneak peek at how ELBO is mentioned in the paper.</p>
<p>In the paper, the authors write</p>
<blockquote>
<p>Training is performed by optimizing the usual variational bound on negative log likelihood:
\(\begin{align}
\mathbb{E}[- \log p_\theta(\mathbf{x}_0)]
& \leq \mathbb{E}_q \left[ - \log \frac{p_\theta (\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T} \vert \mathbf{x}_0)} \right] \tag{8} \\
&= \mathbb{E}_q \left[ - \log p(\mathbf{x}_T) - \sum_{t \geq 1} \log \frac{p_\theta (\mathbf{x}_{t - 1} \vert \mathbf{x}_t)}{q(\mathbf{x}_t \vert \mathbf{x}_{t - 1})} \right] \tag{9} \\
& := L
\end{align}\)</p>
</blockquote>
<p>Equation tags have been added for the purposes of this post.</p>
<p>Admittedly, this does look confusing at first sight, but at its core is the definition of ELBO which we have derived in this post, plus some details inherent to DDPMs, such as Markov chain diffusion. In light of the topic of this post, I will attempt to give the simplest possible explanation of the latter while focusing on the former.</p>
<p>To make things a little more familiar, let’s rewrite (6) to look more like the one presented in the DDPM paper.</p>
\[\begin{align}
\log p(x)
& \geq \mathbb{E}_q [\log p(x, h) - \log q(h \vert x)] \tag{6} \\
&= \mathbb{E}_q \left[ \log \frac{p(x, h)}{q(h \vert x)} \right] \tag{6-1} \\
\end{align}\]
<p>It is not difficult to see that flipping the sign of both sides, which also flips the direction of the inequality, results in an expression that closely resembles (8). We also see a one-to-one correspondence between the variables used in this post and the ones in the paper. Namely, $\mathbf{x}_0$ corresponds to $x$, the ground-truth data, and $\mathbf{x}_{1:T}$ correspond to the hidden representations $h$ of the model.</p>
<p>DDPMs work by starting out with some ground-truth data $\mathbf{x}_0$, then gradually adding Gaussian noise through a Markov chain process. This gradually “breaks” the signal originally present in the data and sends the ground-truth data to an approximately isotropic Gaussian distribution. This process is illustrated below. The figure was taken from the <a href="https://hojonathanho.github.io/diffusion/">author’s website</a>.</p>
<p><img src="https://hojonathanho.github.io/diffusion/assets/img/pgm_diagram_xarrow.png" /></p>
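<p>The forward noising chain is simple enough to simulate directly. The sketch below runs $q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\sqrt{1 - \beta_t}\, \mathbf{x}_{t-1}, \beta_t \mathbf{I})$ step by step on toy one-dimensional "data" that all equal 3. The choices of $T = 1000$ and a linear $\beta_t$ schedule from $10^{-4}$ to $0.02$ are, to my knowledge, the values used in the DDPM paper, but treat them as assumptions here. After the final step, the samples look like draws from a standard normal.</p>

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear noise schedule


def forward_diffusion(x0, betas):
    """Run q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I) step by step."""
    x = x0
    for beta in betas:
        x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)
    return x


# Toy "data": 10,000 one-dimensional samples that all equal 3.
x0 = torch.full((10_000,), 3.0)
xT = forward_diffusion(x0, betas)

# The original signal is essentially gone: mean near 0, standard deviation near 1.
print(xT.mean().item(), xT.std().item())
```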
<p>A neural network is then trained to reverse this Markov chain process by recovering the original signal from the noise. The overall intuition is, in some sense, similar to that of GANs or VAEs, where a network learns to map latent dimensions to the data distribution. An obvious difference is that DDPMs iteratively recover the data, whereas GAN generators usually go directly to the data distribution. The slicing and summation notation in (9) exists precisely due to this iterative nature of the DDPM generative process.</p>
<h1 id="conclusion">Conclusion</h1>
<p>Topics like ELBO and KL divergence are among those concepts that I always think I understand, but in reality do not. The mathematical details underlying those concepts are always intriguing to look at.</p>
<p>While this post in no way covers the entirety of the topic, I hope this will lay a solid foundation for those who want to better understand the mathematics behind latent variable models, such as variational autoencoders, DDPMs, and the like. Personally, I am starting to discover a newfound fascination with DDPMs, and hope to write more about them in the near future.</p>
<p>I hope you enjoyed reading this post. Catch you up in the next one!</p>Jake TaeIn this short post, we will take a look at variational lower bound, also referred to as the evidence lower bound or ELBO for short. While I have referenced ELBO in a previous blog post on VAEs, the proofs and formulations presented in the post seems somewhat overly convoluted in retrospect. One might consider this a gentler, more refined recap on the topic. For the remainder of this post, I will use the terms “variational lower bound” and “ELBO” interchangeably to refer to the same concept. I was heavily inspired by Hugo Larochelle’s excellent lecture on deep belief networks.Reboot2021-05-15T00:00:00+00:002021-05-15T00:00:00+00:00https://jaketae.github.io/blog/reboot<p>It has been a while since I last posted on this blog. Admittedly, a lot has happened in my life: I have been discharged from the Republic of Korea Army, received two full vaccination shots, and am now back home, meeting family and friends all of whom I have dearly missed during the 19-months of my military service. Of course, there are things that haven’t changed as well, such as the importance of this blog and my desire to continue documenting the interesting and random things that I learn every day.</p>
<p>Lately I’ve been realizing how powerful a force inertia is. It was easy to churn out posts every week when blogging was part of my personal norm, almost a habit if you will. Then, when perturbations were introduced to my life, I lost equilibrium and regrettably stopped writing on a regular basis. While I continued studying and committing to new and old repositories on my <a href="https://github.com/jaketae">GitHub</a>, for some inexplicable reason I found it difficult to restart something that I had stopped engaging with. Inertia is insidious, yet it concretizes with time, turning into a substance forceful enough to transform the definition of what personal norm entails.</p>
<p>Today, I was trying to wrap my head around the basics of stochastic differential equations and diffusion models (both of which I still do not understand) until I came across the term “score-based models.” The term “score” comes from the Fisher score, which I had written about some time in the past. It’s an odd feeling when you realize that your past self from a few months back was bright enough to understand concepts that your current self finds abstract and incomprehensible. But this wasn’t the only time I looked up something on my own blog. While there were also times when I spotted my own past mistakes, more often than not I found myself using my own writing as a reference in an attempt to recall some concept or understanding from distant memory.</p>
<p>The conclusion of this admittedly verbose, ostensibly pointless post is that documenting one’s intellectual journey is definitely a worthy endeavor. While the format of this post may appear as a self-promotion of sorts, the intended audience is really my future self, who I hope does not succumb to inertia or, put more bluntly, laziness. So here’s to another round of blogging!</p>Jake TaeIt has been a while since I last posted on this blog. Admittedly, a lot has happened in my life: I have been discharged from the Republic of Korea Army, received two full vaccination shots, and am now back home, meeting family and friends all of whom I have dearly missed during the 19-months of my military service. Of course, there are things that haven’t changed as well, such as the importance of this blog and my desire to continue documenting the interesting and random things that I learn every day.Linear Attention Computation in Nyströmformer2021-03-15T00:00:00+00:002021-03-15T00:00:00+00:00https://jaketae.github.io/study/nystrom-approximation<p>In this post, we will take a look at Nyström approximation, a technique that I came across in <a href="https://arxiv.org/pdf/2102.03902.pdf">Nyströmformer: A Nyström-based Algorithm for Approximating Self-Attention</a> by Xiong et al. This is yet another interesting paper that seeks to reduce the cost of self-attention to linear time. While there are many intricacies to the Nyström method, the goal of this post is to provide a high-level intuition of how the method can be used to approximate large matrices, and how this method was used in the aforementioned paper.</p>
<h1 id="concept">Concept</h1>
<p>Despite its fancy and somewhat intimidating name, the Nyström method has an intuitive explanation. The idea is that, if we know the distance between point A and point B, as well as that between point B and point C, then we can approximate the distance between points A and C as some sort of addition of the two quantities. Of course, if we were discussing distances in one-dimensional space, namely the real number line, we would not merely approximate the distance; we would know the exact quantity. However, in high-dimensional space, this is more difficult, and we can only resort to approximations.</p>
<p>To put things into context, let’s say we want to approximate the attention matrix in the transformer architecture. The Nyström method begins by selecting what the authors of the paper refer to as landmarks. Basically, if we have an attention matrix $A \in \mathbb{R}^{L \times L}$, then we select a few landmark rows and columns to use as the basis or pivot point for our approximation. The goal, then, is to select as few landmarks as possible while being able to approximate the attention matrix as accurately as possible.</p>
<p>For sake of simplicity, let’s say we select the first row and column to be our landmarks. Then, the goal is to approximate the inner sub-matrix $A_\text{sub} \in \mathbb{R}^{(L - 1) \times (L - 1)}$. How might we go about it?</p>
<p>As stated earlier, the intuition is that we use the landmarks as pivot points. Since we selected the first row and column as our landmarks, we have access to $q_1 k_n^\top$ for all $n \leq L$, as well as $q_n k_1^\top$ for all $n \leq L$ (for simplicity, we ignore the normalizing square root). If we recall the motivation behind the transformer’s query-key-value architecture, we can consider attention as a way of calculating the distance, or relevance, between pairs of tokens in a given sequence. Put differently, the landmarks tell us the distance between the first query and all keys, as well as the distance between the first key and all queries.</p>
<p>Using these landmarks, we can approximate the attention score between the $i$th query and the $j$th key for any $i$ and $j$. The approach is similar to the A, B, C example we briefly discussed earlier. We start with the score between the $i$th query and the first key, $q_i k_1^\top$. We then look at the score between the first query and the $j$th key, $q_1 k_j^\top$. Connecting the two dots gives us a sense of how related the $i$th query and the $j$th key are. To remove the redundancy, we divide the product by the self-attention score of the first token, that is, the attention score between the first query and the first key.</p>
\[A_{ij} \approx \frac{q_i k_1^\top \cdot q_1 k_j^\top}{q_1 k_1^\top} \tag{1}\]
<p>Of course, if we have multiple landmarks, we can easily expand the expression above into matrix form. Here, $\tilde{Q}$ and $\tilde{K}$ contain only the landmark rows of $Q$ and $K$, and $\tilde{A}$ denotes the approximation of $A$.</p>
\[\tilde{A} = Q \tilde{K}^\top \times (\tilde{Q} \tilde{K}^\top)^\star \times \tilde{Q} K^\top \tag{2}\]
<p>The superscript $\star$ denotes the Moore-Penrose pseudo-inverse, which remains well-defined even when the landmark block $\tilde{Q} \tilde{K}^\top$ is not invertible.</p>
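<p>As a quick sanity check that (2) generalizes (1), the NumPy sketch below uses random $Q$ and $K$ with the first row as the single landmark (ignoring the scaling factor, as above) and compares the element-wise formula with the matrix form. The variable names are my own.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
L, d = 5, 2
Q = rng.standard_normal((L, d))
K = rng.standard_normal((L, d))

# Equation (1): element-wise, with the first query/key as the only landmark
A1 = np.empty((L, L))
for i in range(L):
    for j in range(L):
        A1[i, j] = (Q[i] @ K[0]) * (Q[0] @ K[j]) / (Q[0] @ K[0])

# Equation (2): matrix form with a single landmark (m = 1)
Q_tilde, K_tilde = Q[:1], K[:1]
A2 = (Q @ K_tilde.T) @ np.linalg.pinv(Q_tilde @ K_tilde.T) @ (Q_tilde @ K.T)

print(np.allclose(A1, A2))  # True
```

With one landmark, the pseudo-inverse of the $1 \times 1$ block is simply the reciprocal of $q_1 k_1^\top$, which is why the two forms coincide.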
<p>Now that we have a general intuition of how Nyström approximation works in the context of attention, let’s get into a basic implementation.</p>
<h1 id="implementation">Implementation</h1>
<p>The goal here is to see that Nyström approximation can indeed yield reasonably accurate results, and that the more landmarks we use, the better the approximation. Consider this a form of Monte Carlo experiment.</p>
<p>Let’s begin by importing some modules.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="o">%</span><span class="n">config</span> <span class="n">InlineBackend</span><span class="p">.</span><span class="n">figure_format</span><span class="o">=</span><span class="s">"retina"</span>
</code></pre></div></div>
<p>To keep things simple, we assume a very basic model with a hidden dimension of 2 and a sequence length of 5, and we omit the batch dimension.</p>
<p>In the context of attention, we would then end up with the following query and key matrices, as well as a five-by-five square attention matrix.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">d_model</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">seq_len</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="n">K</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">@</span> <span class="n">K</span><span class="p">.</span><span class="n">T</span>
<span class="n">A</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(5, 5)
</code></pre></div></div>
<p>The goal, then, is to approximate this square attention matrix.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array([[ 2.29571874, -0.7373519 , 0.32730778, -0.84730782, -1.16558083],
[ 1.4346883 , -0.32765206, 0.80095764, -0.39437617, 0.17889744],
[ 1.38973136, -0.61066937, -0.53783773, -0.67968999, -1.82523199],
[-1.80977456, 0.1036656 , -2.39735444, 0.18320197, -2.33569844],
[ 1.36516091, -0.40695455, 0.33580143, -0.47186895, -0.47836287]])
</code></pre></div></div>
<p>Let’s begin our approximation by assuming the worst case, in which we only have access to one landmark. This brings us back to equation (1): $Q \tilde{K}^\top$ and $\tilde{Q} K^\top$ are just a column vector and a row vector, and the pseudo-inverse reduces to dividing by the scalar $q_1 k_1^\top$.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">num_landmarks</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">Q_tilde</span> <span class="o">=</span> <span class="n">Q</span><span class="p">[:</span><span class="n">num_landmarks</span><span class="p">]</span>
<span class="n">K_tilde</span> <span class="o">=</span> <span class="n">K</span><span class="p">[:</span><span class="n">num_landmarks</span><span class="p">]</span>
</code></pre></div></div>
<p>Recalling equations (1) and (2), we can now write the approximation of the attention matrix as follows.</p>
\[\tilde{A} = Q \tilde{K}^\top \times (\tilde{Q} \tilde{K}^\top)^\star \times \tilde{Q} K^\top\]
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A_tilde</span> <span class="o">=</span> <span class="p">(</span><span class="n">Q</span> <span class="o">@</span> <span class="n">K_tilde</span><span class="p">.</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">pinv</span><span class="p">(</span><span class="n">Q_tilde</span> <span class="o">@</span> <span class="n">K_tilde</span><span class="p">.</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="p">(</span><span class="n">Q_tilde</span> <span class="o">@</span> <span class="n">K</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
<span class="n">A_tilde</span><span class="p">.</span><span class="n">shape</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(5, 5)
</code></pre></div></div>
<p>The shape matches that of the original attention matrix, as expected. If we print out the approximation, we should see exact matches in the first row and column, while the remaining four-by-four block should be roughly similar to that of the original.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A_tilde</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array([[ 2.29571874, -0.7373519 , 0.32730778, -0.84730782, -1.16558083],
[ 1.4346883 , -0.46080128, 0.20454799, -0.52951722, -0.72841901],
[ 1.38973136, -0.44636176, 0.19813834, -0.51292444, -0.7055935 ],
[-1.80977456, 0.58127361, -0.25802521, 0.66795471, 0.91885757],
[ 1.36516091, -0.43847008, 0.19463525, -0.50385594, -0.69311861]])
</code></pre></div></div>
<p>We can quickly verify that the first row and column are exact matches; however, the remaining 16 elements are harder to compare by eye. A more systematic way to compare two matrices is to compute a norm of their difference, such as the Frobenius norm.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A</span> <span class="o">-</span> <span class="n">A_tilde</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>4.33185890598477
</code></pre></div></div>
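<p>For reference, the Frobenius norm is the square root of the sum of squared entries, $\lVert A - \tilde{A} \rVert_F = \sqrt{\sum_{i,j} (A_{ij} - \tilde{A}_{ij})^2}$, and it is what <code class="language-plaintext highlighter-rouge">np.linalg.norm</code> returns by default for a 2D array. The tiny sketch below, using a made-up matrix, spells this out.</p>

```python
import numpy as np

# A small made-up error matrix for illustration
diff = np.array([[0.0, 3.0],
                 [4.0, 0.0]])

manual = np.sqrt(np.sum(diff ** 2))  # square root of the sum of squared entries
builtin = np.linalg.norm(diff)       # Frobenius norm is the default for 2D input

print(manual, builtin)  # 5.0 5.0
```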
<p>If we look at the element-wise difference, we can see that the approximation isn’t too bad for most entries, although a few entries in the fourth row are off by more than 2.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span> <span class="o">-</span> <span class="n">A_tilde</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>array([[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 0.00000000e+00],
[-2.22044605e-16, 1.33149223e-01, 5.96409654e-01,
1.35141056e-01, 9.07316456e-01],
[ 0.00000000e+00, -1.64307605e-01, -7.35976069e-01,
-1.66765549e-01, -1.11963848e+00],
[ 0.00000000e+00, -4.77608006e-01, -2.13932924e+00,
-4.84752738e-01, -3.25455600e+00],
[ 0.00000000e+00, 3.15155316e-02, 1.41166181e-01,
3.19869853e-02, 2.14755744e-01]])
</code></pre></div></div>
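<p>The zeros in the first row and column are not a coincidence. Writing the landmark block as $W = \tilde{Q}\tilde{K}^\top$, the landmark rows of the approximation are $W W^\star \tilde{Q} K^\top$ and the landmark columns are $Q \tilde{K}^\top W^\star W$, and both reduce to the original entries whenever $W$ is invertible. The sketch below checks this with landmarks at arbitrary positions rather than the first few; <code class="language-plaintext highlighter-rouge">nystrom_approx</code> and its <code class="language-plaintext highlighter-rouge">idx</code> argument are hypothetical names for illustration.</p>

```python
import numpy as np

def nystrom_approx(Q, K, idx):
    """Approximate Q @ K.T using the landmark rows/columns listed in idx."""
    Q_tilde, K_tilde = Q[idx], K[idx]
    return (Q @ K_tilde.T) @ np.linalg.pinv(Q_tilde @ K_tilde.T) @ (Q_tilde @ K.T)

rng = np.random.default_rng(1)
Q = rng.standard_normal((8, 3))
K = rng.standard_normal((8, 3))
A = Q @ K.T

idx = [2, 5]  # two landmarks at arbitrary positions
A_tilde = nystrom_approx(Q, K, idx)

# Landmark rows and columns are reproduced up to floating-point error
print(np.allclose(A_tilde[idx, :], A[idx, :]))  # True
print(np.allclose(A_tilde[:, idx], A[:, idx]))  # True
```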
<h2 id="monte-carlo-approach">Monte Carlo Approach</h2>
<p>Let’s extend this one-landmark trial to larger matrices and more landmarks. For ease of execution, I’ve wrapped each step outlined above into functions.</p>
<p>The first function, <code class="language-plaintext highlighter-rouge">norms_by_landmarks</code>, receives query and key matrices, then approximates the attention matrix while varying the number of landmarks from 1 to the full sequence length. The Frobenius norm of the difference measures how good each approximation is. In theory, we should expect a downward-sloping pattern as the number of landmarks grows.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">norms_by_landmarks</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">):</span>
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">A</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">@</span> <span class="n">K</span><span class="p">.</span><span class="n">T</span>
<span class="k">for</span> <span class="n">num_landmarks</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">Q</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">Q_tilde</span> <span class="o">=</span> <span class="n">Q</span><span class="p">[:</span><span class="n">num_landmarks</span><span class="p">]</span>
<span class="n">K_tilde</span> <span class="o">=</span> <span class="n">K</span><span class="p">[:</span><span class="n">num_landmarks</span><span class="p">]</span>
<span class="n">A_tilde</span> <span class="o">=</span> <span class="p">(</span><span class="n">Q</span> <span class="o">@</span> <span class="n">K_tilde</span><span class="p">.</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">pinv</span><span class="p">(</span><span class="n">Q_tilde</span> <span class="o">@</span> <span class="n">K_tilde</span><span class="p">.</span><span class="n">T</span><span class="p">)</span> <span class="o">@</span> <span class="p">(</span><span class="n">Q_tilde</span> <span class="o">@</span> <span class="n">K</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>
<span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="n">norm</span><span class="p">(</span><span class="n">A</span> <span class="o">-</span> <span class="n">A_tilde</span><span class="p">))</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">result</span><span class="p">)</span>
</code></pre></div></div>
<p>The second function, <code class="language-plaintext highlighter-rouge">run_experiments</code>, is a wrapper around the first one. It repeats the same experiment for a specified number of iterations and averages the results. The purpose of repetition is to reduce the role of luck, where randomly initialized key and query matrices happen to be configured in a way that makes the Nyström approximation perform unusually well or poorly. By repeating the experiment and averaging the result, which is the spirit behind Monte Carlo approximations, we can have more confidence in our final result.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">run_experiments</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">num_iter</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="n">result</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_iter</span><span class="p">):</span>
<span class="n">Q</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="n">K</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="n">norm</span> <span class="o">=</span> <span class="n">norms_by_landmarks</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">)</span>
<span class="n">result</span> <span class="o">+=</span> <span class="n">norm</span>
<span class="k">return</span> <span class="n">result</span> <span class="o">/</span> <span class="n">num_iter</span>
</code></pre></div></div>
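<p>Since the averaged curve can still be noisy, it may also be useful to look at the spread across trials. Below is a hypothetical variant, <code class="language-plaintext highlighter-rouge">run_experiments_with_std</code>, that takes a <code class="language-plaintext highlighter-rouge">seed</code> for reproducibility (via NumPy’s <code class="language-plaintext highlighter-rouge">default_rng</code>) and returns the per-landmark mean and standard deviation. It is a sketch of my own, not part of the original experiment, and it redefines <code class="language-plaintext highlighter-rouge">norms_by_landmarks</code> so that it runs on its own.</p>

```python
import numpy as np

def norms_by_landmarks(Q, K):
    # Frobenius error of the Nyström approximation for 1 ... len(Q) landmarks
    A = Q @ K.T
    result = []
    for m in range(1, len(Q) + 1):
        Q_tilde, K_tilde = Q[:m], K[:m]
        A_tilde = (Q @ K_tilde.T) @ np.linalg.pinv(Q_tilde @ K_tilde.T) @ (Q_tilde @ K.T)
        result.append(np.linalg.norm(A - A_tilde))
    return np.asarray(result)

def run_experiments_with_std(d_model, seq_len, num_iter=10, seed=0):
    # Hypothetical variant of run_experiments: seeded, and also reports the spread
    rng = np.random.default_rng(seed)
    trials = np.stack([
        norms_by_landmarks(rng.standard_normal((seq_len, d_model)),
                           rng.standard_normal((seq_len, d_model)))
        for _ in range(num_iter)
    ])  # shape: (num_iter, seq_len)
    return trials.mean(axis=0), trials.std(axis=0)

mean, std = run_experiments_with_std(d_model=10, seq_len=50)
print(mean.shape, std.shape)  # (50,) (50,)
```

Plotting <code class="language-plaintext highlighter-rouge">mean</code> with a band of <code class="language-plaintext highlighter-rouge">std</code> around it would show where the trials disagree most.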
<p>Here, we assume a sequence length of 50 and a hidden (embedding) size of 10. And off we go!</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">norms</span> <span class="o">=</span> <span class="n">run_experiments</span><span class="p">(</span><span class="n">d_model</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">seq_len</span><span class="o">=</span><span class="mi">50</span><span class="p">)</span>
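<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">norms</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span> <span class="n">norms</span><span class="p">)</span>  <span class="c1"># norms[0] is the error with one landmark</span>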
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<p> <br />
<img src="/assets/images/2021-03-15-nystrom-approximation_files/2021-03-15-nystrom-approximation_30_0.png" />
</p>
<h1 id="conclusion">Conclusion</h1>
<p>While there is some noise in the final outcome, we do see that beyond a certain number of landmarks, the approximation yields near-exact results. In this case, it seems to happen at around 10 landmarks.</p>
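<p>The location of that elbow is not arbitrary. Since $Q, K \in \mathbb{R}^{L \times d}$, the attention matrix $A = QK^\top$ has rank at most $d$ (the <code class="language-plaintext highlighter-rouge">d_model</code> of 10 used above). With $d$ landmarks, $\tilde{Q}$ and $\tilde{K}$ are $d \times d$ and, for generic random inputs, invertible, so $Q \tilde{K}^\top (\tilde{Q} \tilde{K}^\top)^{-1} \tilde{Q} K^\top = Q \tilde{K}^\top \tilde{K}^{-\top} \tilde{Q}^{-1} \tilde{Q} K^\top = QK^\top$ and the approximation is exact. The sketch below checks this with a fixed seed, comparing $d - 1$ and $d$ landmarks; the helper name <code class="language-plaintext highlighter-rouge">nystrom_error</code> is my own.</p>

```python
import numpy as np

def nystrom_error(Q, K, m):
    # Frobenius error when the first m rows of Q and K serve as landmarks
    Q_tilde, K_tilde = Q[:m], K[:m]
    A_tilde = (Q @ K_tilde.T) @ np.linalg.pinv(Q_tilde @ K_tilde.T) @ (Q_tilde @ K.T)
    return np.linalg.norm(Q @ K.T - A_tilde)

rng = np.random.default_rng(0)
d_model, seq_len = 10, 50
Q = rng.standard_normal((seq_len, d_model))
K = rng.standard_normal((seq_len, d_model))
scale = np.linalg.norm(Q @ K.T)  # used to report relative errors

# One landmark short of d_model: a rank-(d_model - 1) approximation of a rank-d_model matrix
print(nystrom_error(Q, K, d_model - 1) / scale)
# Exactly d_model landmarks: the error drops to floating-point round-off
print(nystrom_error(Q, K, d_model) / scale)
```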
<p>Transformers have now taken over much of the ML world, even beyond NLP. Recently, I came across a paper titled <a href="https://arxiv.org/abs/2103.05247">Pretrained Transformers are Universal Computation Engines</a>. Apparently, pretrained transformer LMs can perform extremely well on new tasks with minimal fine-tuning. Specifically, even if the feedforward and attention portions of the network are frozen, which amounts to nearly 99 percent of the model, transformer LMs can be micro-tuned to a wide array of tasks that are not even NLP-related.</p>
<p>While there is certainly a possibility that researchers will announce a new SOTA model architecture in the near future, similar to how transformers made LSTMs obsolete in many fields, I think transformers are here to stay for a while. And it’s certainly interesting to see attempts to make them even better, lighter, and faster. Nyströmformer was one such attempt, and I hope to see more.</p>Jake TaeIn this post, we will take a look at Nyström approximation, a technique that I came across in Nyströmformer: A Nyström-based Algorithm for Approximating Self-Attention by Xiong et al. This is yet another interesting paper that seeks to bring the cost of self-attention down to linear time. While there are many intricacies to the Nyström method, the goal of this post is to provide a high-level intuition of how the method can be used to approximate large matrices, and how this method was used in the aforementioned paper.Relative Positional Encoding2021-03-01T00:00:00+00:002021-03-01T00:00:00+00:00https://jaketae.github.io/study/relative-positional-encoding<p>In this post, we will take a look at relative positional encoding, as introduced in <a href="https://arxiv.org/pdf/1803.02155.pdf">Shaw et al. (2018)</a> and refined by <a href="https://arxiv.org/pdf/1809.04281.pdf">Huang et al. (2018)</a>. This is a topic I meant to explore earlier, but only recently was I able to really force myself to dive into it, as I started reading about music generation with language models. That is a separate topic for another post, so let’s not get distracted.</p>
<p>Let’s dive right into it!</p>
<h1 id="concept">Concept</h1>
<p>If you’re already familiar with transformers, you probably know that transformers process all tokens of the input in parallel. This is one of the many reasons why transformers have been so much more successful than RNNs: RNNs struggle to capture long-range dependencies due to their recurrent structure, whereas transformers do not have this problem since they can see the entire sequence at once. However, this also means that transformers require positional encodings to inform the model about where specific tokens are located in the context of a full sequence. Otherwise, a transformer would be entirely blind to token order, treating “John likes cats” and “Cats like John” as identical. Hence, positional encodings are used to signal the absolute position of each token.</p>
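<p>To see why positional information is needed at all, here is a small NumPy sketch of single-head self-attention with no positional encoding (random projection matrices, no masking, names of my own choosing). Permuting the input tokens simply permutes the output rows in the same way, so the layer itself has no notion of order.</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    # Plain scaled dot-product self-attention with no positional information
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores) @ V

rng = np.random.default_rng(0)
seq_len, d_model = 3, 4
X = rng.standard_normal((seq_len, d_model))  # three token embeddings
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

perm = np.array([2, 1, 0])  # reverse the token order
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Reordering the input only reorders the output rows
print(np.allclose(out[perm], out_perm))  # True
```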
<h2 id="relative-positional-encoding">Relative Positional Encoding</h2>
<p>While absolute positional encodings work reasonably well, there have also been efforts to exploit pairwise, relative positional information. In <a href="https://arxiv.org/pdf/1803.02155.pdf">Self-Attention with Relative Position Representations</a>, Shaw et al. introduced a way of using pairwise distances to create positional encodings.</p>
<p>There are a number of reasons why we might want to use relative positional encodings instead of absolute ones. Most notably, using learned absolute positional embeddings means that there is a hard limit to the number of tokens a model can process. Say a language model can only encode up to 1024 positions. Then any sequence longer than 1024 tokens cannot be processed by the model. Relative pairwise distances can solve this problem more gracefully, though not without limitations. Relative positional encodings can generalize to sequences of unseen lengths, since in principle the only information they encode is the relative pairwise distance between two tokens.</p>
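<p>In Shaw et al., this generalization comes from clipping: relative distances beyond a maximum absolute value $k$ are clipped to $\pm k$, so only $2k + 1$ relative embeddings are learned, and a sequence of any length maps onto that same fixed set. A minimal NumPy sketch of the clipped index matrix follows; <code class="language-plaintext highlighter-rouge">clipped_relative_positions</code> is a hypothetical helper name.</p>

```python
import numpy as np

def clipped_relative_positions(seq_len, k):
    # Entry (i, j) is the distance j - i, clipped to the range [-k, k]
    pos = np.arange(seq_len)
    rel = pos[None, :] - pos[:, None]
    return np.clip(rel, -k, k)

k = 2
print(clipped_relative_positions(5, k))

# Shift by k so every entry indexes one of the 2k + 1 embedding rows,
# no matter how long the sequence is
idx = clipped_relative_positions(1000, k) + k
print(idx.min(), idx.max())  # 0 4
```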
<p>Relative positional information is supplied to the model in two places: the keys and the values. This becomes apparent in the two modified self-attention equations shown below. First, relative positional information is added to the keys.</p>
\[e_{ij} = \frac{x_i W^Q (x_j W^K + a_{ij}^K)^\top}{\sqrt{d_z}} \tag{1}\]
<p>The softmax operation remains unchanged from vanilla self-attention.</p>
\[\alpha_{ij} = \frac{\text{exp} \space e_{ij}}{\sum_{k = 1}^n \text{exp} \space e_{ik}}\]
<p>Lastly, relative positional information is added to the values as well.</p>
\[z_i = \sum_{j = 1}^n \alpha_{ij} (x_j W^V + a_{ij}^V) \tag{2}\]
<p>In other words, instead of simply combining semantic embeddings with absolute positional ones, relative positional information is added to keys and values on the fly during attention calculation.</p>
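<p>Equations (1) and (2) translate almost directly into code. Below is a single-head, unbatched NumPy sketch that takes the relative embeddings $a_{ij}^K$ and $a_{ij}^V$ as full $(n, n, d_z)$ arrays; the function name and the random inputs are illustrative assumptions, and how these arrays are built (e.g., with clipping) is left out.</p>

```python
import numpy as np

def relative_self_attention(X, W_Q, W_K, W_V, a_K, a_V):
    """Single-head relative self-attention following equations (1) and (2).

    X: (n, d_x); W_Q, W_K, W_V: (d_x, d_z); a_K, a_V: (n, n, d_z).
    """
    q, k, v = X @ W_Q, X @ W_K, X @ W_V
    d_z = q.shape[-1]
    # Equation (1): e_ij = q_i . (k_j + a_ij^K) / sqrt(d_z)
    e = np.einsum("id,ijd->ij", q, k[None, :, :] + a_K) / np.sqrt(d_z)
    # Softmax over j (max subtracted for numerical stability)
    e = e - e.max(axis=-1, keepdims=True)
    alpha = np.exp(e) / np.exp(e).sum(axis=-1, keepdims=True)
    # Equation (2): z_i = sum_j alpha_ij (v_j + a_ij^V)
    return np.einsum("ij,ijd->id", alpha, v[None, :, :] + a_V)

rng = np.random.default_rng(0)
n, d_x, d_z = 4, 6, 3
X = rng.standard_normal((n, d_x))
W_Q, W_K, W_V = (rng.standard_normal((d_x, d_z)) for _ in range(3))
a_K = rng.standard_normal((n, n, d_z))
a_V = rng.standard_normal((n, n, d_z))

Z = relative_self_attention(X, W_Q, W_K, W_V, a_K, a_V)
print(Z.shape)  # (4, 3)
```

Setting both <code class="language-plaintext highlighter-rouge">a_K</code> and <code class="language-plaintext highlighter-rouge">a_V</code> to zeros recovers vanilla scaled dot-product attention, which is a handy way to test such a function.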
<h2 id="bridging-shaw-and-huang">Bridging Shaw and Huang</h2>
<p>In Huang et al., also known as the music transformer paper, the authors pointed out that calculating relative positional encodings as introduced in Shaw et al. requires $O(L^2D)$ memory due to the introduction of an additional relative positional encoding tensor. Here, $L$ denotes the length of the sequence, and $D$ the hidden state dimension used by the model. Huang et al. introduced a new way of computing relative positional encodings via a clever skewing operation, which reduces this memory requirement to $O(LD)$.</p>
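<p>To get a feel for the $O(L^2 D)$ term, here is a back-of-the-envelope calculation with illustrative numbers that I picked: a sequence length of 2048, a hidden size of 512, and 32-bit floats. Storing one embedding per query-key pair takes about 8.6 GB for a single layer, compared to a few megabytes for a table with roughly one embedding per distance.</p>

```python
L, D = 2048, 512       # illustrative sequence length and hidden size
bytes_per_float = 4    # float32

full_R = L * L * D * bytes_per_float    # one D-dim embedding per (query, key) pair
per_distance = L * D * bytes_per_float  # one D-dim embedding per relative distance

print(full_R / 1e9)        # 8.589934592 (GB)
print(per_distance / 1e6)  # 4.194304 (MB)
```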
<p>To cut to the chase, below is the relative attention mechanism suggested by the authors in Huang et al.</p>
\[\text{RelativeAttention} = \text{Softmax} \left( \frac{Q K^\top + S_{rel}}{\sqrt{D_h}} \right) V \tag{3}\]
<p>It seems that in the music transformer paper, the authors dropped the additional relative positional embedding that corresponds to the value term and focused only on the key component. In other words, they only use (1), not (2).</p>
<p>The notation in (1), (2), and (3) was borrowed verbatim from the respective papers, so some care is needed when reconciling them. Specifically, $S_{rel}$ in the music transformer paper is simply</p>
\[S_{rel} = Q R^\top\]
<p>where</p>
\[R_{ij} = a_{ij}^K\]
<p>In other words, (3) is just an expanded variant of (1).</p>
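<p>A quick numerical check of this correspondence for a single head: with $R_{ij} = a_{ij}^K$ stored as an $(L, L, D_h)$ array, $S_{rel}$ is the per-pair dot product $q_i \cdot R_{ij}$, and $QK^\top + S_{rel}$ matches the numerator of (1) entry by entry. The random arrays are stand-ins for the projected queries, keys, and relative embeddings.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
L, D_h = 5, 4
Q = rng.standard_normal((L, D_h))     # projected queries, x_i W^Q
K = rng.standard_normal((L, D_h))     # projected keys, x_j W^K
R = rng.standard_normal((L, L, D_h))  # R[i, j] plays the role of a_ij^K

S_rel = np.einsum("id,ijd->ij", Q, R)  # S_rel[i, j] = q_i . R[i, j]
logits_eq3 = Q @ K.T + S_rel

# Numerator of equation (1), computed pair by pair
logits_eq1 = np.array([[Q[i] @ (K[j] + R[i, j]) for j in range(L)] for i in range(L)])

print(np.allclose(logits_eq3, logits_eq1))  # True
```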
<p>To make things a little clearer, let’s review the dimensions of each tensor. First, from vanilla self-attention, we know that $Q \in \mathbb{R}^{H \times L \times D_h}$, where $H$ denotes the number of heads. $R$ is a tensor of relative positional embeddings: each $R_{ij}$ is a $D_h$-dimensional embedding of the distance between positions $i$ and $j$, so $R \in \mathbb{R}^{L \times L \times D_h}$ (it can be shared across heads). The product $S_{rel} = QR^\top$ is taken pair by pair, $(S_{rel})_{hij} = q_{hi} \cdot R_{ij}$, which gives $S_{rel} \in \mathbb{R}^{H \times L \times L}$, the same shape as $QK^\top$. Intuitively, $R$ can also be understood as the result of passing a matrix of relative positional indices through an embedding layer.</p>
<h2 id="efficient-computation">Efficient Computation</h2>
<p>The skewing mechanism introduced in Huang et al. is ingenious, but it isn’t black magic. The technique can roughly be understood as a set of clever padding and reshaping operations that ultimately produce $S_{rel}$ without explicitly creating or computing $R$. We want to avoid computing $R$ because it is a huge memory bottleneck: the tensor requires $O(L^2 D)$ extra space.</p>
<p>Before looking at the skewing itself, it helps to see what $R$ is built from. For concreteness, here is a dummy function that creates relative positional indices.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">relative_positions</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
<span class="n">result</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">):</span>
<span class="n">front</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">end</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span> <span class="o">-</span> <span class="n">i</span><span class="p">))</span>
<span class="n">result</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">front</span> <span class="o">+</span> <span class="n">end</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span>
</code></pre></div></div>
<p>Let’s see what the indices look like for a sequence of length five.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">relative_positions</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>[[0, 1, 2, 3, 4],
[-1, 0, 1, 2, 3],
[-2, -1, 0, 1, 2],
[-3, -2, -1, 0, 1],
[-4, -3, -2, -1, 0]]
</code></pre></div></div>
<p>We can understand each row as corresponding to the current query position, and each entry as the signed distance $j - i$ between that query and the token at index $j$. A quick disclaimer: this example does not strictly follow the details outlined in Shaw et al. For instance, this function does not take into account $k$, the clipping distance that determines the width of the window. The 0-based indexing scheme is also from Huang et al.
Minor details aside, I think having a clear sense of what $R$ is helps a lot in understanding relative attention, as well as the skewing mechanism introduced in Huang et al. For a fuller explanation of these concepts, I highly recommend <a href="https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a">this medium article</a>.</p>
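<p>The same index matrix can be produced without Python loops via broadcasting, and indexing an embedding table with it is exactly the "embedding layer" view of $R$. In the sketch below, the table <code class="language-plaintext highlighter-rouge">E</code> (one row per distance from $-(L-1)$ to $L-1$, offset so indices start at 0) is a made-up stand-in for learned embeddings, and <code class="language-plaintext highlighter-rouge">relative_positions_np</code> is a hypothetical name.</p>

```python
import numpy as np

def relative_positions_np(seq_len):
    # Entry (i, j) is j - i, matching relative_positions above
    pos = np.arange(seq_len)
    return pos[None, :] - pos[:, None]

seq_len, d_head = 5, 3
rel = relative_positions_np(seq_len)
print(rel[1])  # [-1  0  1  2  3]

# Hypothetical embedding table with one row per distance -(seq_len - 1) ... seq_len - 1
rng = np.random.default_rng(0)
E = rng.standard_normal((2 * seq_len - 1, d_head))
R = E[rel + seq_len - 1]  # embedding lookup: R[i, j] = E[(j - i) + seq_len - 1]
print(R.shape)  # (5, 5, 3)
```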
<p>Below is a visual summary of the skewing mechanism.</p>
<p><img src="/assets/images/relative_attn_skewing.png" /></p>
<p>Personally, I found this diagram a bit confusing at first. However, with much staring and imagination, I slowly started to realize that the skewing is simply a way of transforming $QE_r^\top$ into $QR^\top$, where $E_r$ is the relative positional embedding matrix.</p>
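<p>The skewing step on its own is only a pad, a reshape, and a slice. The NumPy sketch below assumes the causal setting, with $E_r$ holding one embedding per relative distance ordered from $-(L-1)$ up to $0$; it computes $QE_r^\top$, skews it, and checks the lower triangle (the part that survives the causal mask) against a direct lookup of $q_i \cdot e_{j-i}$. The function name <code class="language-plaintext highlighter-rouge">skew</code> and the random inputs are my own choices for illustration.</p>

```python
import numpy as np

def skew(QEr):
    # QEr: (L, L); column c holds scores for relative distance c - (L - 1)
    L = QEr.shape[0]
    padded = np.pad(QEr, ((0, 0), (1, 0)))  # prepend a dummy zero column: (L, L + 1)
    reshaped = padded.reshape(L + 1, L)     # reinterpret the same entries: (L + 1, L)
    return reshaped[1:]                     # drop the first row: (L, L)

rng = np.random.default_rng(0)
L, d_head = 6, 4
Q = rng.standard_normal((L, d_head))
Er = rng.standard_normal((L, d_head))  # row r embeds distance r - (L - 1)

S_rel = skew(Q @ Er.T)

# Direct computation: S_rel[i, j] = q_i . e_{j - i}, valid for j <= i
expected = np.zeros((L, L))
for i in range(L):
    for j in range(i + 1):
        expected[i, j] = Q[i] @ Er[j - i + L - 1]

# Entries above the diagonal are leftovers that the causal mask discards
print(np.allclose(np.tril(S_rel), expected))  # True
```

No $(L, L, D_h)$ tensor is ever created here: the only extra memory beyond $QE_r^\top$ itself is one padded column.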
<p>Instead of trying to explain this in plain text, I decided that implementing the entire relative global attention would not only help with demonstration, but also cement my own understanding of how this works.</p>
<h1 id="implementation">Implementation</h1>
<p>This implementation of relative global attention was in large part influenced by Karpathy’s <a href="https://github.com/karpathy/minGPT">minGPT</a>, which we discussed in <a href="https://jaketae.github.io/study/gpt/">this previous post</a>, as well as Prayag Chatha’s implementation of the music transformer, available on GitHub <a href="https://github.com/chathasphere/pno-ai">here</a>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>
</code></pre></div></div>
<p>Below is a simple implementation of a relative global attention layer. I’ve deviated from Chatha’s implementation in a number of ways, but the most important one is how I treat the relative positional embedding matrix. In Shaw et al., the authors note that “[relative positional embeddings] can be shared across attention heads.” Hence, I’m using one <code class="language-plaintext highlighter-rouge">Er</code> matrix for all heads instead of one per head. This matrix is registered as an <code class="language-plaintext highlighter-rouge">nn.Parameter</code>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">RelativeGlobalAttention</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">max_len</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span>
<span class="nb">super</span><span class="p">().</span><span class="n">__init__</span><span class="p">()</span>
        <span class="n">d_head</span><span class="p">,</span> <span class="n">remainder</span> <span class="o">=</span> <span class="nb">divmod</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">remainder</span><span class="p">:</span>
            <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span>
                <span class="s">"incompatible `d_model` and `num_heads`"</span>
            <span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">max_len</span> <span class="o">=</span> <span class="n">max_len</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">d_model</span> <span class="o">=</span> <span class="n">d_model</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">key</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">value</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">query</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">Er</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">d_head</span><span class="p">))</span>
        <span class="bp">self</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span>
            <span class="s">"mask"</span><span class="p">,</span>
            <span class="n">torch</span><span class="p">.</span><span class="n">tril</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="n">max_len</span><span class="p">,</span> <span class="n">max_len</span><span class="p">))</span>
            <span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">).</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
        <span class="p">)</span>
        <span class="c1"># self.mask.shape = (1, 1, max_len, max_len)
</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># x.shape == (batch_size, seq_len, d_model)
</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">shape</span>
        <span class="k">if</span> <span class="n">seq_len</span> <span class="o">></span> <span class="bp">self</span><span class="p">.</span><span class="n">max_len</span><span class="p">:</span>
            <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span>
                <span class="s">"sequence length exceeds model capacity"</span>
            <span class="p">)</span>
        <span class="n">k_t</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">key</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="c1"># k_t.shape = (batch_size, num_heads, d_head, seq_len)
</span>        <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">value</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">q</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">query</span><span class="p">(</span><span class="n">x</span><span class="p">).</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">).</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="c1"># shape = (batch_size, num_heads, seq_len, d_head)
</span>
        <span class="n">start</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">max_len</span> <span class="o">-</span> <span class="n">seq_len</span>
        <span class="n">Er_t</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">Er</span><span class="p">[</span><span class="n">start</span><span class="p">:,</span> <span class="p">:].</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Er_t.shape = (d_head, seq_len)
</span>        <span class="n">QEr</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">Er_t</span><span class="p">)</span>
        <span class="c1"># QEr.shape = (batch_size, num_heads, seq_len, seq_len)
</span>        <span class="n">Srel</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">skew</span><span class="p">(</span><span class="n">QEr</span><span class="p">)</span>
        <span class="c1"># Srel.shape = (batch_size, num_heads, seq_len, seq_len)
</span>
        <span class="n">QK_t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">k_t</span><span class="p">)</span>
        <span class="c1"># QK_t.shape = (batch_size, num_heads, seq_len, seq_len)
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="p">(</span><span class="n">QK_t</span> <span class="o">+</span> <span class="n">Srel</span><span class="p">)</span> <span class="o">/</span> <span class="n">math</span><span class="p">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">q</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
        <span class="n">mask</span> <span class="o">=</span> <span class="bp">self</span><span class="p">.</span><span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="n">seq_len</span><span class="p">,</span> <span class="p">:</span><span class="n">seq_len</span><span class="p">]</span>
        <span class="c1"># mask.shape = (1, 1, seq_len, seq_len)
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="p">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="s">"-inf"</span><span class="p">))</span>
        <span class="c1"># attn.shape = (batch_size, num_heads, seq_len, seq_len)
</span>        <span class="n">attn</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">out</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span>
        <span class="c1"># out.shape = (batch_size, num_heads, seq_len, d_head)
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="c1"># out.shape == (batch_size, seq_len, num_heads, d_head)
</span>        <span class="n">out</span> <span class="o">=</span> <span class="n">out</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># out.shape == (batch_size, seq_len, d_model)
</span>        <span class="k">return</span> <span class="bp">self</span><span class="p">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">out</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">skew</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">QEr</span><span class="p">):</span>
        <span class="c1"># QEr.shape = (batch_size, num_heads, seq_len, seq_len)
</span>        <span class="n">padded</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span><span class="n">QEr</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
        <span class="c1"># padded.shape = (batch_size, num_heads, seq_len, 1 + seq_len)
</span>        <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span> <span class="o">=</span> <span class="n">padded</span><span class="p">.</span><span class="n">shape</span>
        <span class="n">reshaped</span> <span class="o">=</span> <span class="n">padded</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">num_cols</span><span class="p">,</span> <span class="n">num_rows</span><span class="p">)</span>
        <span class="c1"># reshaped.shape = (batch_size, num_heads, 1 + seq_len, seq_len)
</span>        <span class="n">Srel</span> <span class="o">=</span> <span class="n">reshaped</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">:,</span> <span class="p">:]</span>
        <span class="c1"># Srel.shape = (batch_size, num_heads, seq_len, seq_len)
</span>        <span class="k">return</span> <span class="n">Srel</span>
</code></pre></div></div>
<p>Most of the operations in the <code class="language-plaintext highlighter-rouge">forward</code> method are direct translations of the equations we discussed above. The interesting bit happens in the <code class="language-plaintext highlighter-rouge">skew</code> method: we pad $Q E_r^\top$ with a column of zeros on the left, reshape the padded matrix from $L \times (L + 1)$ to $(L + 1) \times L$ so that each successive row is shifted by one extra position, and then slice off the first row to obtain $Q R^\top$, or $S_{rel}$. This reduces the memory requirement: we never have to materialize $R$, an $L \times L \times d$ tensor that costs $O(L^2 d)$ memory, and can instead use $E_r$ directly, a matrix we need anyway, so the memory requirement drops to $O(Ld)$. In my view, this is one of the biggest contributions of Huang et al.</p>
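<p>To convince ourselves that the pad, reshape, and slice trick really computes $S_{rel}$, here is a minimal sketch that compares it against a naive version which builds $R$ explicitly. It is written in NumPy rather than PyTorch only so that it runs without a deep learning stack (<code class="language-plaintext highlighter-rouge">np.pad</code> stands in for <code class="language-plaintext highlighter-rouge">F.pad</code>), and the helper names <code class="language-plaintext highlighter-rouge">skew</code> and <code class="language-plaintext highlighter-rouge">naive_srel</code> are my own. It assumes the indexing convention implied by <code class="language-plaintext highlighter-rouge">self.Er[start:, :]</code> above, as I read it: row $k$ of $E_r$ embeds relative distance $k - (L - 1)$, so the last row is distance 0.</p>

```python
import numpy as np


def skew(QEr):
    """NumPy mirror of the `skew` method (hypothetical helper).

    QEr[..., i, k] = q_i . Er[k], where Er[k] is assumed to embed the
    relative distance k - (L - 1).
    """
    L = QEr.shape[-1]
    # Pad one zero column on the left of the last axis: (..., L, L + 1)
    pad_width = [(0, 0)] * (QEr.ndim - 1) + [(1, 0)]
    padded = np.pad(QEr, pad_width)
    # Reinterpret the same row-major data as (..., L + 1, L)
    reshaped = padded.reshape(*QEr.shape[:-2], L + 1, L)
    # Drop the first row to get back to (..., L, L)
    return reshaped[..., 1:, :]


def naive_srel(q, Er):
    """Shaw et al.-style S_rel = Q R^T with R materialized as (L, L, d)."""
    L, d = q.shape
    R = np.zeros((L, L, d))
    for i in range(L):
        for j in range(i + 1):  # causal: only j <= i is ever used
            R[i, j] = Er[(L - 1) + (j - i)]
    # S_rel[i, j] = sum_d q[i, d] * R[i, j, d]
    return np.einsum("id,ijd->ij", q, R), R


rng = np.random.default_rng(0)
L, d = 5, 4
q = rng.standard_normal((L, d))
Er = rng.standard_normal((L, d))

Srel_skew = skew(q @ Er.T)
Srel_naive, R = naive_srel(q, Er)

# Compare only the lower triangle (j <= i); the causal mask discards the rest
lower = np.tril(np.ones((L, L), dtype=bool))
print(np.allclose(Srel_skew[lower], Srel_naive[lower]))  # True
print(R.size, Er.size)  # 100 20
```

<p>Only the lower triangle is compared because, above the diagonal, the skewed matrix holds leftover values from the next row (or the padded zeros), which the causal mask throws away anyway. The last line also makes the memory argument concrete: $R$ has $L^2 d$ entries while $E_r$ has only $Ld$.</p>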
<p>Let’s quickly check that the layer works as intended by running a basic tensor shape check.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">8</span>
<span class="n">seq_len</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">d_model</span> <span class="o">=</span> <span class="mi">768</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="mi">12</span>
<span class="n">test_in</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">)</span>
<span class="n">layer</span> <span class="o">=</span> <span class="n">RelativeGlobalAttention</span><span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)</span>
<span class="n">layer</span><span class="p">(</span><span class="n">test_in</span><span class="p">).</span><span class="n">shape</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.Size([8, 100, 768])
</code></pre></div></div>
<p>We get an output of size <code class="language-plaintext highlighter-rouge">(batch_size, seq_len, d_model)</code>, which is what we expect.</p>
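<p>A shape check tells us the tensors line up, but not that the causal mask and the skewing interact correctly. As a complementary sanity check, here is a minimal single-head sketch of the same computation in NumPy: my own simplified re-implementation with hypothetical weight matrices <code class="language-plaintext highlighter-rouge">Wq</code>, <code class="language-plaintext highlighter-rouge">Wk</code>, and <code class="language-plaintext highlighter-rouge">Wv</code>, no batch or head dimensions, and no dropout. It verifies that perturbing the last token leaves every earlier output unchanged, which is what a causal layer should guarantee.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; -inf entries become 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def skew(QEr):
    # Pad left, reshape (L, L + 1) -> (L + 1, L), drop the first row
    L = QEr.shape[-1]
    pad_width = [(0, 0)] * (QEr.ndim - 1) + [(1, 0)]
    padded = np.pad(QEr, pad_width)
    return padded.reshape(*QEr.shape[:-2], L + 1, L)[..., 1:, :]


def relative_attention(x, Wq, Wk, Wv, Er):
    """Single-head causal relative attention on x of shape (L, d)."""
    L, d = x.shape
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    # Mirror of self.Er[start:, :]: keep the last L relative embeddings
    Srel = skew(q @ Er[-L:].T)
    logits = (q @ k.T + Srel) / np.sqrt(d)
    causal = np.tril(np.ones((L, L), dtype=bool))
    logits = np.where(causal, logits, -np.inf)
    return softmax(logits) @ v


rng = np.random.default_rng(0)
L, d, max_len = 6, 8, 10
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
Er = rng.standard_normal((max_len, d))
x = rng.standard_normal((L, d))

out = relative_attention(x, Wq, Wk, Wv, Er)
x_changed = x.copy()
x_changed[-1] += 10.0  # perturb only the last position
out_changed = relative_attention(x_changed, Wq, Wk, Wv, Er)

print(out.shape)  # (6, 8)
print(np.allclose(out[:-1], out_changed[:-1]))  # True
```

<p>The same idea carries over to the PyTorch layer: feed two inputs that differ only at the final time step and check that the outputs agree everywhere except at that step (with dropout disabled via <code class="language-plaintext highlighter-rouge">eval()</code>).</p>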
<h1 id="conclusion">Conclusion</h1>
<p>In this post, we discussed relative positional encoding as introduced in Shaw et al., and saw how Huang et al. made the algorithm far more memory-efficient with the skewing procedure.</p>
<p>Relative positional encodings have also been used in other architectures, such as Transformer-XL and, more recently, DeBERTa, which I also plan on reviewing soon. Relative positioning is arguably much closer to how we humans read text. While it is probably not a good idea to always compare and conflate model architectures with how the human brain works, I still think it’s an interesting way to think about these concepts.</p>
<p>This post was also a healthy exercise in that it forced me to understand every single detail. Every sentence and diagram in a paper can be of huge help when you are trying to actually implement the ideas it outlines. I can see why <a href="https://paperswithcode.com">Papers with Code</a> became such a big thing: it’s always helpful to see actual implementations and, even better, reproducible results. For this post in particular, referencing Music Transformer implementations on GitHub and re-reading the paper many times really helped me nail down points that were initially confusing or unclear.</p>
<p>I hope you’ve enjoyed reading this post. Catch you in the next one!</p>Jake TaeIn this post, we will take a look at relative positional encoding, as introduced in Shaw et al. (2018) and refined by Huang et al. (2018). This is a topic I meant to explore earlier, but only recently was I able to really force myself to dive into this concept as I started reading about music generation with NLP language models. This is a separate topic for another post of its own, so let’s not get distracted.