From Guide to Machine Learning, Jul 9, 2023
How Large Language Model training works
Large Language Models are typically trained in a multi-stage pipeline, with different datasets, tasks, and metrics at each stage. In this post, we'll introduce the basics of training and break down these stages.
At a high level, training a Large Language Model (LLM) involves two main steps:
- Pretraining, in the context of LLMs¹, involves training on a large, unlabeled dataset. The goal of pretraining is to learn a general language representation that can be used for a variety of downstream tasks. For example, a model pretrained on a large corpus of English text can be used for tasks like text generation, sentiment analysis, and question answering.
- Finetuning involves taking a pretrained model and training it on a smaller, labeled dataset. The goal of finetuning is to adapt the pretrained model to a specific task. For example, a model may be finetuned to generate text in the style of a particular author.
Throughout this post, we'll focus on practical tips for each of these steps, covering the key ingredients of each phase: datasets, hyperparameters, and metrics. We'll also discuss some of the challenges that arise when training LLMs, along with recent advances that address them.
The primary intuition behind pretraining is simple: we have relatively little labeled data but vast amounts of unlabeled text on the internet, so how can we leverage this large unlabeled corpus to build better natural language models? The general answer in deep learning is to train the model with self-supervised learning.
In self-supervised learning, the model is trained to predict some attribute of the input data. These attributes provide "supervision" to the model, but they are not provided by humans. Instead, they are generated automatically from the data. For example:
- In computer vision, a model may be trained to predict the rotation of a randomly rotated image². The idea is that rotation can only be predicted correctly if the model builds a visual understanding of different classes: that horses stand right-side up, and that boats sit in water rather than in the sky. Constructing the dataset is simple: take a large collection of images and randomly rotate each one. The model is then trained to take the rotated image as input and predict the rotation that was applied.
- In natural language (e.g., for LLMs), a model may be trained to predict the next word in a sentence. The idea is similar: the next word can only be predicted correctly if the model builds an understanding of natural language, for example that the "quick brown fox" can jump over lazy dogs, fences, and picnic tables³. Constructing this dataset is also simple: take a large collection of text and, for every prefix of that text, ask the model to predict the word that follows.
In both cases, the model is trained to predict part of its own input, so no human annotation is required. In this sense, the model has an effectively unlimited supply of training data.
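To make the two dataset recipes concrete, here is a minimal Python sketch of each. It assumes images are NumPy arrays and splits text on whitespace; real LLM pipelines work on subword tokens rather than words, and the function names are our own, not from any library.

```python
import random

import numpy as np


def make_rotation_examples(images, seed=0):
    """Pair each image with a random multiple of 90 degrees; the rotation is the label."""
    rng = random.Random(seed)
    examples = []
    for image in images:
        k = rng.randrange(4)  # 0, 1, 2, or 3 quarter-turns
        examples.append((np.rot90(image, k), k))
    return examples


def make_next_word_examples(text):
    """Pair every prefix of the text with the word that immediately follows it."""
    words = text.split()
    return [(" ".join(words[:i]), words[i]) for i in range(1, len(words))]


if __name__ == "__main__":
    for context, target in make_next_word_examples("the quick brown fox jumps"):
        print(f"{context!r} -> {target!r}")
```

Running it prints four (context, next word) pairs, from 'the' -> 'quick' through 'the quick brown fox' -> 'jumps'. Note that neither function needs any human-provided label.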
These examples were chosen for ease of explanation, but LLMs can leverage a variety of pretraining tasks⁴. The two most popular are Causal Language Modeling (CLM) and Masked Language Modeling (MLM):
- Causal Language Modeling (CLM): More plainly called next-word prediction, this task trains the model to predict the next word. For example, given the phrase "the quick brown", the model may be asked to predict the next word, "fox". This is done repeatedly: the model then takes the context "the quick brown fox" and predicts "jumps", then takes "the quick brown fox jumps" and predicts "over", and so on. In this way, the model can generate entire sequences of words despite only being trained to predict one word at a time. Most LLMs are trained with this objective.
- Masked Language Modeling (MLM): The model is trained to predict a masked word in a sentence⁵. For example, given the sentence "the quick brown fox jumps over the lazy dog", the model may be asked to predict "fox" from the context "the quick brown [MASK] jumps over the lazy dog". This objective was used by popular natural language backbones such as BERT and RoBERTa, but for LLMs it has since largely been replaced by CLM.
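A minimal sketch of how training pairs could be built for each objective, assuming the sentence is already split into tokens. For CLM, the shifted form trains every position at once by offsetting inputs and targets by one token. For MLM, each token is masked independently with probability 15% (the rate BERT uses); the sketch omits BERT's refinement of sometimes substituting a random or unchanged token instead of [MASK]. Function names are illustrative.

```python
import random

MASK = "[MASK]"


def clm_examples(tokens):
    """Causal LM: for each position, the context is every token before it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def clm_shifted(tokens):
    """The same CLM targets in batched form: inputs and targets offset by one token."""
    return tokens[:-1], tokens[1:]


def mlm_example(tokens, mask_prob=0.15, seed=0):
    """Masked LM: hide a random subset of tokens; targets map position -> original token."""
    rng = random.Random(seed)
    inputs, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            inputs.append(MASK)
            targets[i] = tok  # the model must recover this token
        else:
            inputs.append(tok)
    return inputs, targets


if __name__ == "__main__":
    sentence = "the quick brown fox jumps over the lazy dog".split()
    print(clm_examples(sentence)[:2])
    print(clm_shifted(sentence))
    print(mlm_example(sentence, mask_prob=0.3))
```

The key difference is visible in the inputs: CLM only ever sees tokens to the left of its target, while MLM sees context on both sides of each mask.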
These terms are useful for categorizing pretraining tasks, but they are not mutually exclusive. For example, a model may be trained to predict the next word in a sentence but, with some probability, be asked to predict a masked word instead, so that it learns both tasks. Training one model on multiple tasks like this is called multi-task learning; T5, for example, trained a single model on a mixture of tasks cast into a shared text-to-text format.
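One way such a mixed objective could be sampled is sketched below. The 25% masking probability and the function name are hypothetical choices for illustration; published mixtures such as T5's are considerably more elaborate.

```python
import random

MASK = "[MASK]"


def sample_training_example(tokens, p_masked=0.25, rng=random):
    """Return (task, input_tokens, target) for one of two self-supervised tasks.

    With probability p_masked, one random token is masked and must be recovered;
    otherwise, a random prefix is given and the following token is the target.
    """
    if rng.random() < p_masked:
        i = rng.randrange(len(tokens))
        masked = tokens[:i] + [MASK] + tokens[i + 1:]
        return "masked", masked, tokens[i]
    i = rng.randrange(1, len(tokens))
    return "next", tokens[:i], tokens[i]


if __name__ == "__main__":
    rng = random.Random(0)
    sentence = "the quick brown fox jumps over the lazy dog".split()
    for _ in range(3):
        print(sample_training_example(sentence, rng=rng))
```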
The data sources for pretraining LLMs are fairly standard; the most popular include the following:
- Wikipedia: This encyclopedia that needs no introduction offers 21 GB of text across 6.7 million articles.
- StackExchange: The most popular subsite, StackOverflow, offers around 60 GB of data across 24 million questions.
- GitHub: This host for open-source code offers over 21 TB of data across 372 million repositories.
- CommonCrawl: A large repository of web crawl data. The latest crawl in October 2022 includes 380 tebibytes of data from 3.15 billion webpages.
- arXiv: A repository of over 2.2 million research papers in physics, mathematics, computer science, and other fields.
Unlike the handful of raw data sources, the curated datasets built from them are numerous and constantly growing. This is due in part to the realization that data quality, rather than data quantity or model size, was the bottleneck for LLM quality. In light of this, one of the first such datasets, C4, was curated from CommonCrawl. Since then, a variety of datasets have been released; the most notable are listed below, from oldest to newest:
- C4: A curated version of CommonCrawl containing 305 GB of text. This dataset is available for download.
- The Pile: 825 GiB of text pooled largely from academic sources. This dataset is available for download.
- MassiveText: 10.5 TB of text across 2.35 billion documents, curated from sources such as C4, Wikipedia, and GitHub. This dataset has not been publicly released, but it's worth mentioning because several notable models, such as DeepMind's Gopher and Chinchilla, were trained on it.
- RedPajama: A 1.2 trillion token dataset curated by following the data collection recipe described for LLaMA. This dataset is available for download.
- SlimPajama: A 627 billion token dataset curated from RedPajama, largely by deduplication. This dataset is available for download.
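Deduplication is the core idea behind SlimPajama, so here is a minimal sketch of its simplest form: exact-match deduplication by hashing lightly normalized text. This is an illustration under our own assumptions, not SlimPajama's actual pipeline; real pipelines typically also remove near-duplicates with fuzzy matching, which this sketch does not attempt.

```python
import hashlib


def normalize(text):
    """Lowercase and collapse whitespace so trivially different copies hash the same."""
    return " ".join(text.lower().split())


def deduplicate(documents):
    """Keep the first occurrence of each document, comparing hashes of normalized text."""
    seen = set()
    unique = []
    for doc in documents:
        digest = hashlib.sha256(normalize(doc).encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique.append(doc)
    return unique


if __name__ == "__main__":
    docs = ["Hello  world", "hello world", "Something else entirely"]
    print(deduplicate(docs))
```

Storing fixed-size hashes rather than full documents keeps the "seen" set small, which matters at the scale of hundreds of billions of tokens.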
There are many more datasets, but for the purposes of this discussion, we will focus on the above datasets as they are the most popular and most widely used.
At this stage, the model has not yet been trained to perform any specific task, so the metrics are geared towards general "language understanding", or more specifically, how well the model has modeled the data. Several metrics aim to capture this; most are computed from the model's output logits rather than from the sequences it generates:
- Perplexity: The most common metric is perplexity, which is a measure of how "surprised" the model is to see a particular sentence. Ideally, sentences from our training distribution (i.e., realistic sentences) should have low perplexity, meaning the model determines these sentences are "typical", while sentences from a different distribution (i.e., nonsense collections of words) should have high perplexity. As a result, lower perplexity is better.
- Bilingual Evaluation Understudy (BLEU): This metric is commonly used in machine translation to measure how well the model's output matches the ground truth. More generally, BLEU measures how similar a predicted text is to a set of reference texts by counting overlapping n-grams. The score ranges from 0 to 1, with 1 indicating that the predicted text is identical to one of the reference texts, so a higher BLEU score is better. Note that, unlike the other metrics here, BLEU is computed on generated text, and it does not take grammar or meaning into account: it only rewards surface-level overlap with the references. Related metrics such as ROUGE and METEOR work similarly.
- Zero-shot Classification Accuracy: Generally, "zero-shot" refers to any task the model is asked to perform without having been trained on examples of that task. In the LLM context specifically, the model is asked to rank the probability of two possible sentences, and the dataset labels which one human annotators prefer. Accuracy is then how often the model assigns the higher probability to the human-preferred sentence. This metric is useful for measuring how well the model has learned to model language that humans prefer. As with any other classification task, higher accuracy is better.
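Perplexity is the exponential of the average negative log-likelihood the model assigns to each actual next token: PPL = exp(-(1/N) * sum over i of log p(x_i | x_<i)). The sketch below computes it from a matrix of logits, assuming some model (not shown) has already produced one row of vocabulary scores per position.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def perplexity(logits, targets):
    """exp of the mean negative log-likelihood of the tokens that actually came next.

    logits:  (seq_len, vocab_size) scores for each position's next token
    targets: (seq_len,) integer ids of the observed next tokens
    """
    log_probs = log_softmax(logits)
    token_log_probs = log_probs[np.arange(len(targets)), targets]
    return float(np.exp(-token_log_probs.mean()))


if __name__ == "__main__":
    # A model with uniform predictions over 4 tokens is "as surprised as" a 4-way guess.
    print(perplexity(np.zeros((3, 4)), np.array([0, 1, 2])))
```

A useful intuition: uniform predictions over a vocabulary of size V give a perplexity of V, while a model that is always confident and correct approaches a perplexity of 1.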
For the most part, metrics that require no annotation, such as perplexity, BLEU, and their variants, are more common, especially on larger corpora. Accuracy, on the other hand, is standard to report on smaller labeled datasets such as WinoGrande, SuperGLUE, and GLUE.
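Zero-shot accuracy on such benchmarks can be sketched as picking, for each example, the candidate sentence the model scores highest and comparing it with the label. The sketch assumes the per-candidate total log-probabilities have already been computed by a model; the numbers in the usage example are made up. Some evaluations additionally normalize scores by sentence length, which this sketch omits.

```python
def zero_shot_accuracy(option_log_probs, labels):
    """Fraction of examples where the highest-scoring candidate is the labeled one.

    option_log_probs: per example, the model's total log-probability for each candidate
    labels: per example, the index of the correct (human-preferred) candidate
    """
    correct = 0
    for scores, label in zip(option_log_probs, labels):
        prediction = max(range(len(scores)), key=scores.__getitem__)
        correct += prediction == label
    return correct / len(labels)


if __name__ == "__main__":
    # Hypothetical scores for three two-way examples; the model gets 2 of 3 right.
    scores = [[-12.3, -15.1], [-20.0, -18.5], [-9.0, -9.5]]
    print(zero_shot_accuracy(scores, [0, 1, 1]))
```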
Finetuning is the process of taking a pretrained model and training it on a specific task.
- This may involve a more specific dataset. For example, a model pretrained on a large corpus of text can be finetuned on a dataset of movie reviews to better model and detect abnormal reviews. In this case, the model is pretrained on a large corpus of generic text such as Wikipedia, and finetuned on a smaller dataset of movie reviews from IMDB, for example.
- This may involve a dataset for a more specific task. In short, the pretrained model has only been trained to model language, meaning it can simply complete text. This is a far cry from the LLMs that we interact with by asking questions, requesting summaries, or otherwise instructing the model to perform a language task. To bridge this gap, most LLMs are finetuned using instruction tuning, a process we describe in more detail below: in short, the model is further trained to follow instructions rather than just complete text.
There are a number of different finetuning objectives, but for the purposes of this discussion, we will focus on instruction tuning as it is the most popular and most widely used.
Instruction tuning involves finetuning on natural language datasets where the input text contains instructions for the task at hand. For example, the input text may be "Is this review positive or negative? Review: I absolutely loved the french fries." There are many tasks that a model can be instruction tuned on, such as the sentiment analysis example above.
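Turning an ordinary labeled example into an instruction-tuning example is mostly a matter of wrapping it in a natural-language template. The sketch below is illustrative: the templates, field names, and function name are hypothetical, and rotating between several phrasings is intended to keep the model from latching onto one exact wording.

```python
import random

# Hypothetical instruction templates for a sentiment-analysis task.
TEMPLATES = [
    "Is this review positive or negative? Review: {review}",
    "Review: {review}\nDid the reviewer like it? Answer positive or negative.",
]


def to_instruction_example(review, label, rng=random):
    """Wrap a (review, label) pair as an instruction-following input/output pair."""
    template = rng.choice(TEMPLATES)
    return {"input": template.format(review=review), "output": label}


if __name__ == "__main__":
    print(to_instruction_example("I absolutely loved the french fries.", "positive"))
```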
One approach is to simply enumerate these tasks and exhaustively collect datasets to instruction tune your model on, which is what FLAN more or less did – simply collect a large number of smaller-scale datasets that contain instructions and desired outputs. This ranged from the tasks mentioned above to general question-answering, reading comprehension, machine translation, and more.
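Once many task datasets are collected, they still have to be combined into one training stream. FLAN's exact mixing scheme may differ from what is shown here; this minimal sketch samples each task in proportion to its size, optionally capped so a huge dataset cannot drown out small ones, similar in spirit to the "examples-proportional mixing" described in the T5 paper. The dataset names are hypothetical.

```python
import random


def mix_datasets(datasets, num_examples, cap=None, seed=0):
    """Draw a training stream from several instruction datasets.

    datasets: dict mapping task name -> list of {"input": ..., "output": ...} examples
    cap: optional upper bound on each dataset's sampling weight
    """
    rng = random.Random(seed)
    names = list(datasets)
    weights = [len(datasets[n]) if cap is None else min(len(datasets[n]), cap) for n in names]
    stream = []
    for _ in range(num_examples):
        name = rng.choices(names, weights=weights)[0]  # pick a task
        stream.append((name, rng.choice(datasets[name])))  # then an example from it
    return stream


if __name__ == "__main__":
    data = {
        "sentiment": [{"input": "Is this review positive?", "output": "yes"}] * 3,
        "qa": [{"input": "What is 2 + 2?", "output": "4"}] * 1000,
    }
    stream = mix_datasets(data, 10, cap=3)
    print([name for name, _ in stream])
```

Without the cap, the 3-example sentiment set would appear in well under 1% of draws; with cap=3, the two tasks are sampled equally often.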
However, this approach does not scale well, as it requires substantial human effort to collect each dataset. To address this, OpenAI popularized a general technique called RLHF, which learns a model of human feedback that can then score, or pseudo-label, new data at scale.
Reinforcement Learning from Human Feedback (RLHF)⁶ is a finetuning objective that uses human feedback to train the model. There are three main steps in this process:
- Human Feedback: The LLM generates a sequence of text, and a human annotator provides feedback on its quality. This feedback can be binary (e.g., thumbs up or thumbs down) or a score (e.g., 1-5 stars). At this point, we have a dataset of (text, human feedback) pairs.
- Reward Function: The reward function is a separate model that takes in a sequence of text and outputs a score. It is trained on the (text, human feedback) pairs collected above, so it learns to predict human preferences.
- Policy Optimization: The LLM is then trained to maximize the reward function using policy optimization, a family of reinforcement learning techniques for training a model to maximize an arbitrary reward; InstructGPT, for example, used Proximal Policy Optimization (PPO). Since the reward function represents human feedback, the LLM is in effect trained to maximize human preferences.
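The reward-function step can be illustrated at toy scale. This sketch fits a linear reward model to binary thumbs-up/thumbs-down feedback with logistic regression; the feature vectors stand in for whatever representation of the text a real reward model would compute (in practice, the reward model is itself a large network, and InstructGPT trained its reward model on human rankings of multiple outputs rather than binary labels). All names here are illustrative.

```python
import numpy as np


def train_reward_model(features, feedback, lr=0.1, steps=500):
    """Fit a linear reward r(x) = w . x + b to binary human feedback.

    features: (n, d) array, one row per generated text (a stand-in for a text embedding)
    feedback: (n,) array of 1.0 (thumbs up) or 0.0 (thumbs down)
    Trained with logistic loss, so sigmoid(r(x)) estimates P(thumbs up | text).
    """
    n, d = features.shape
    w, b = np.zeros(d), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(features @ w + b)))
        grad = p - feedback  # derivative of the logistic loss w.r.t. the reward
        w -= lr * features.T @ grad / n
        b -= lr * grad.mean()
    return lambda x: x @ w + b


if __name__ == "__main__":
    texts = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
    votes = np.array([1.0, 1.0, 0.0, 0.0])
    reward = train_reward_model(texts, votes)
    print(reward(texts))  # liked texts score higher than disliked ones
```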
This process – RLHF – is a general technique for aligning the LLM with human preferences. To use RLHF for instruction tuning, the datasets used for RLHF include input texts that contain an instruction, rather than generic or undirected text. For example, the input text may be "Write a summary of this article" or "Write a review of this movie". The human-provided score is then based on how well the LLM follows the instruction. In this way, the LLM is trained to follow instructions.
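To show what "training the LLM to maximize the reward" means mechanically, here is a deliberately tiny sketch: the "policy" is a softmax over a few fixed candidate responses to one instruction, and it is updated with REINFORCE, a simpler policy-gradient method than the PPO used by InstructGPT. The reward values are hypothetical reward-model scores, and the sketch omits details such as the KL penalty that keeps a real LLM close to its starting point.

```python
import numpy as np


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def reinforce(rewards, lr=0.5, steps=2000, seed=0):
    """Train a softmax policy over candidate responses to maximize a fixed reward.

    rewards: (k,) reward-model scores for k candidate responses to one instruction
    Returns the final probability the policy assigns to each candidate.
    """
    rng = np.random.default_rng(seed)
    logits = np.zeros(len(rewards))
    baseline = 0.0
    for _ in range(steps):
        probs = softmax(logits)
        a = rng.choice(len(rewards), p=probs)  # sample a response from the policy
        advantage = rewards[a] - baseline
        grad_log_prob = -probs  # gradient of log softmax(logits)[a] ...
        grad_log_prob[a] += 1.0  # ... is onehot(a) - probs
        logits += lr * advantage * grad_log_prob
        baseline = 0.9 * baseline + 0.1 * rewards[a]  # running average reduces variance
    return softmax(logits)


if __name__ == "__main__":
    print(reinforce(np.array([0.1, 1.0, -0.5])))  # mass shifts to the highest-reward response
```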
There are a number of datasets that can be used for instruction tuning, falling into a number of ever-expanding categories. We will discuss a few of the most popular ones here.
- Summarization: In short, reduce the number of words in a document without changing its meaning. A few of the most popular include news summarization (Multi-News, XSum), conversation summarization (SamSum, AESLC), document summarization (WikiLingua) and others (Gigaword).
- Question-Answering: Answer questions about a particular document (SQuAD, OBQA) or about open-domain knowledge without a particular reference (ARC, TriviaQA, Natural Questions).
- Translation: Translate between a variety of languages (WMT-16, paracrawl).
There are a large number of other tasks that you can explore on the HuggingFace datasets and Google Research datasets webpages: text classification, text generation, and sentence similarity to name a few.
Evaluating an instruction-tuned model is still an active area of research, but there are a number of options available today:
- One option is to reuse the evaluation metrics we applied to pretrained models, as the LLM worksheet does. However, finetuned models tend to score worse on metrics such as perplexity and BLEU, because those metrics assess how closely the model's outputs match reference text rather than how well the model follows instructions. As a result, we need new metrics that are better suited to instruction-tuned models.
- To this end, one recent approach, Chatbot Arena, uses user votes to compute Elo ratings. Elo ratings are generally used to rank players in zero-sum games like chess; here, researchers from Berkeley proposed applying the Elo rating system to rank LLMs.
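The standard Elo update is simple enough to sketch. Each user vote is treated as a game between two models; the winner gains rating in proportion to how unexpected the win was. K = 32 is a conventional default from chess, not necessarily what Chatbot Arena uses, the model names are hypothetical, and ties are ignored; the Arena's actual procedure may differ.

```python
def expected_score(rating_a, rating_b):
    """Probability that A beats B under the Elo model."""
    return 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400))


def update_elo(ratings, winner, loser, k=32):
    """Apply one pairwise vote, moving both ratings toward the observed outcome."""
    surprise = 1.0 - expected_score(ratings[winner], ratings[loser])
    ratings[winner] += k * surprise
    ratings[loser] -= k * surprise


if __name__ == "__main__":
    ratings = {"model_a": 1000.0, "model_b": 1000.0, "model_c": 1000.0}
    votes = [("model_a", "model_b"), ("model_a", "model_c"), ("model_c", "model_b")]
    for winner, loser in votes:
        update_elo(ratings, winner, loser)
    print(sorted(ratings.items(), key=lambda kv: -kv[1]))
```

One design caveat: sequential Elo updates depend on the order of the votes, so two orderings of the same votes can produce slightly different rankings.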
Neither option is ideal: the first doesn't directly assess instruction-following ability, and the second requires significant amounts of human annotation. We look to future research to address these issues.
This is a rough overview of the stages involved in training a deployable LLM, from pretraining to finetuning and instruction tuning. We covered standard losses and metrics, as well as the datasets used to train these models. That said, there is still a lot to cover before we have a complete training pipeline. We have the foundations so far, so stay tuned: we'll cover more practical elements of training in subsequent posts.
¹ Pretraining more broadly can refer to any training on a large(r) dataset. For example, in the context of computer vision, pretraining can refer to training a model on ImageNet, a 1.2 million image dataset that comes with labels for each image's class (e.g., dog or cat). In the context of LLMs, pretraining refers to training on a dataset of billions of words. Colloquially, "LLM pretraining" does not usually make use of labels.
² This notion of predicting rotations was first introduced by Gidaris et al. (2018) and has since been used in many other works, including Grill et al. (2020) and Liu et al. (2021). There are also other unsupervised learning tasks in computer vision, such as recoloring a grayscale image (Zhang et al., 2016) or rearranging shuffled patches of an image, like solving a puzzle (Noroozi et al., 2016).
³ Granted, the model could also simply memorize all possible next words that follow "the quick brown fox". However, there are many more possible next words than the model has parameters, so the model can't memorize all of them, at least in theory. As a result, the hope is that the model is forced to learn a general representation of the data rather than memorizing specific examples. Think of this as a form of regularization.
"Red team attacks" are a way for crowdworkers to attack a model. You can learn more from Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned.