This piece touches on a few related questions, including:
- Why does training an LLM on the internet’s data work?
- Why does prompting with “You are an expert” help?
- Why is chat fine-tuning (reinforcement learning/DPO) useful?
- Why is lower perplexity not always better?
In short, language models work because good ideas tend to be replicated (not because being replicated makes an idea any more true).
*Isn’t it surprising…?*
…that training a language model on the web’s data works? Isn’t there a lot of rubbish on the web? How does the model know what information to focus on?
“Correct” or “high quality” data has a tendency to repeat itself. There might be some data saying that the sun moves around the earth, but there is much more saying that the earth moves around the sun. This makes it possible for us to train language models on large datasets, even if they contain some, or even a lot of, weak information or arguments. Weak arguments and information tend to differ from each other, whereas stronger arguments and information tend to be articulated, transmitted and replicated more coherently.
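To make the frequency point concrete, here is a toy sketch of what a purely frequency-based “model” does with conflicting claims. The ten-sentence corpus is made up for illustration, not how a real LLM is trained:

~~~python
from collections import Counter

# Hypothetical toy corpus: the correct claim simply appears more often.
corpus = ["the earth moves around the sun"] * 9 + ["the sun moves around the earth"] * 1

# A purely frequency-based "model": which word follows "moves around the"?
prefix = "moves around the "
counts = Counter()
for sentence in corpus:
    idx = sentence.find(prefix)
    if idx != -1:
        counts[sentence[idx + len(prefix):].split()[0]] += 1

total = sum(counts.values())
print({word: count / total for word, count in counts.items()})
# {'sun': 0.9, 'earth': 0.1} -> the model leans towards the more frequent (correct) claim
~~~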
There’s a very important directionality to be aware of: i) good explanations/ideas are more likely to be replicated, NOT ii) an idea/explanation being widely replicated makes it true:
~~~
i) X is a good idea/explanation => X is more likely to be spread in written/oral format.
NOT
ii) A majority of people/experts say X => X is necessarily true/correct/better.
~~~
This gets to the heart of why language models currently cannot reason like humans. Language models are frequency based, not explanation based. There was a time when many people thought the Earth was the centre of the universe. Not only was that wrong, but we didn’t get around it with new data. We got around it with new explanations. The same is true with general relativity improving over Newton’s laws - it wasn’t brought about by new data, but by a new explanation.
You may feel I’m digressing a bit. I’m not. Understanding language models as STATISTICAL MODELS helps us understand a lot about what works well for training and why.
Getting back to the narrow point… The internet is a reasonably good dataset to train an LLM on because humans tend to replicate better ideas/information, i.e. the frequency of good ideas and information is sufficiently far above that of the bad information.
Still, the “poor/inaccurate” information blurs the model’s answers. Think of the internet’s data as rolling hills with the model always trying to climb to the highest peak. The more rubbish dumps dotting the landscape, the harder it is for the model to identify the highest peak. The cleaner the hills, the sharper the tallest peak and the easier it is for the model to identify the best answer.
*Why does it help to tell the model it is an expert at coding?*
A language model has no way to identify good data or bad data other than to look at the probabilities based on the data it has seen.
A model trained on the internet will have code from GitHub, text from Wikipedia and much more. Now, you prepend your question with “You are an expert at coding…”.
That phrase, and text like it, appears far more often alongside the GitHub data than in the rest of the web’s data. This drags the model’s probability distribution over answer words towards those found in GitHub. On average, these answers will be more related to code and will be better answers!
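Here is a rough sketch of that effect, using GPT-2 via Hugging Face transformers purely as a stand-in for a web-trained model. The helper function, the prompts and the “expert” prefix are my own illustration; the point is just to compare the probability the model assigns to the same code answer with and without the prefix:

~~~python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model choice; any causal LM would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def continuation_logprob(prompt: str, continuation: str) -> float:
    """Total log-probability the model assigns to `continuation` given `prompt`."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + continuation, return_tensors="pt").input_ids
    labels = full_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100  # only score the continuation tokens
    with torch.no_grad():
        out = model(full_ids, labels=labels)
    n_scored = (labels != -100).sum()
    return (-out.loss * n_scored).item()  # loss is mean NLL, so undo the mean

question = "Write a function that reverses a string.\n"  # prompt ends on a clean token boundary
answer = "def reverse(s):\n    return s[::-1]\n"

print(continuation_logprob(question, answer))
print(continuation_logprob("You are an expert programmer. " + question, answer))
~~~

If the “expert” prefix drags the distribution towards code-like text, the second log-probability should be higher than the first.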
Training techniques all involve some form of statistically dragging the model towards high quality, relevant information, and dragging it away from low quality information.
*Why data quality matters*
You’ll read everywhere that data quality matters. The LIMA paper teaches that “less is more” for certain kinds of fine-tuning - provided the data is of very high quality, i.e. it’s better to have a small amount of great data than a lot of bad data.
If you train a model on data about a house, a garden and the wheelie bins outside, but there’s nothing good in the bins, then you’re better off leaving out the bins altogether so that they don’t add noise to the statistical distribution.
There’s a model called BLOOM 176B with 176 billion parameters. It’s being crushed by much smaller models, probably even by 7B models in some cases, like Zephyr. And that isn’t because of architecture - they’re all using transformer architectures. It’s just that everyone is using much cleaner data (and smaller models are being trained with more data, rather than larger models with less data, see the Chinchilla paper, although training is now done way past the Chinchilla-optimal amount of data).
Mistral isn’t much different from Llama 2 in architecture (apart from tweaks to the attention layout) but it’s probably trained for longer AND has better data. Zephyr is better than Mistral, but it was only trained for a few hours!!! That’s data quality, not architecture.
Companies are going back to smaller models - Llama 2’s biggest model is 70B parameters. xAI’s model is around 30-40B parameters.
*Why do big companies do Reinforcement Learning?*
Reinforcement Learning has been seen as an important “alignment” step done on models after initial training to make models safer and more “human like”. I’ll give a different articulation here - reinforcement learning shifts the model’s statistical distribution away from bad answers towards good answers.
Standard training - where you take part of a sentence, get the model to predict the next word, and penalise/adjust the model based on how far its prediction was from the actual next word - doesn’t inherently allow the model to distinguish between good and bad data. With standard training the model can only distinguish between good and bad based on the frequency of what it sees.
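In code, that standard objective looks roughly like this (a minimal PyTorch sketch, where `model` stands in for any causal LM returning logits). Notice that the loss only ever compares the prediction against whatever word actually came next, with no notion of whether that word came from good or bad data:

~~~python
import torch
import torch.nn.functional as F

# Sketch of the standard next-token objective on one batch of token ids.
# `model` is assumed to return logits of shape (batch, seq_len, vocab_size).
def next_token_loss(model, token_ids: torch.Tensor) -> torch.Tensor:
    logits = model(token_ids).logits      # predictions at every position
    pred = logits[:, :-1, :]              # prediction made at position t for position t+1
    target = token_ids[:, 1:]             # the word that actually came next
    # Penalise the model by how unlikely it found the actual next word,
    # regardless of whether that word came from good or bad data.
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
~~~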
Reinforcement learning adds a different tool allowing us to push the statistical distribution away from bad answers and towards good answers.
As “not just a side note”, Reinforcement Learning used to be very complicated - involving the training of an entirely separate helper (reward) model. This is no longer the case with Direct Preference Optimisation (DPO), which shifts the model from bad towards good answers in a single training stage, with no separate reward model. So, the message that reinforcement learning is only possible for big companies is changing.
In DPO, you take a prompt and create (with a human or a language model) two answers - a good one and a bad one. Then, you run those prompt+answer pairs through your model. For the “good” answer, you nudge the model to increase the probability it assigns to generating that answer, scaled by a factor of beta relative to the original (reference) model, and backpropagate through the model. For the “bad” answer, you nudge the model to decrease the probability it assigns to generating that answer in the same way.
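Written out for one good/bad pair, the DPO objective looks roughly like this (a minimal PyTorch sketch; the function and variable names are my own):

~~~python
import torch
import torch.nn.functional as F

# Sketch of the DPO loss for one prompt with a "good" and a "bad" answer.
# Each argument is the summed log-probability of that answer's tokens under
# the model being trained (policy) or the frozen starting model (reference).
def dpo_loss(policy_logp_good: torch.Tensor, policy_logp_bad: torch.Tensor,
             ref_logp_good: torch.Tensor, ref_logp_bad: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # How much more (or less) the policy likes each answer than the reference does.
    good_shift = policy_logp_good - ref_logp_good
    bad_shift = policy_logp_bad - ref_logp_bad
    # Reward the policy for raising the good answer relative to the bad one,
    # scaled by beta, measured against the frozen reference model.
    return -F.logsigmoid(beta * (good_shift - bad_shift))
~~~

Minimising this loss pushes probability mass towards the good answer and away from the bad one, while beta controls how far the model is allowed to drift from the reference.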
The good+bad answer dataset is expensive to generate with humans, but it is powerful in how it can shift the model towards using better parts of its statistical distribution. This is why ChatGPT will occasionally show you pairs of answers and ask you to choose between the two (presumably, to use for training). Alternatively, you can use a stronger language model to generate the preference data used to train a weaker model (e.g. Zephyr was trained on data curated using GPT-4, Llama, Falcon and MPT).
DPO or Reinforcement learning statistically drags the model away from the trash can and towards the house. In a sense, reinforcement learning or DPO allows model trainers to make up for the fact that their datasets contain bad explanations and information.
It’s pulling the model away from bad answers and towards good ones. Of course, it would have been better if we didn’t have the bad quality answers in there in the first place!!!
*Lower perplexity is not (necessarily) better!*
Perplexity is a measure of how poorly a language model predicts its training data (or some smaller benchmark dataset, often WikiText - a sample from Wikipedia).
When an LLM is being trained, the perplexity goes down and down until eventually it plateaus, i.e. the model gets better and better at predicting its training data.
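Concretely, perplexity is just the exponential of the average per-token negative log-likelihood. A toy calculation:

~~~python
import math

# Perplexity = exp(average negative log-likelihood per token).
def perplexity(token_log_probs: list[float]) -> float:
    avg_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(avg_nll)

# A model that assigns every token probability 0.25 has perplexity 4 -
# it is as "surprised" as if it were choosing between 4 equally likely words.
print(perplexity([math.log(0.25)] * 100))  # ≈ 4.0
~~~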
Interestingly, certain types of training, like reinforcement learning (or now, DPO), can increase the perplexity of a model! Why??? And is that good or bad?
Quite simply, DPO moves the statistical distribution of a model away from bad data and towards better data. It’s like moving the model away from the rubbish dump to focus on the rolling hills. If your perplexity benchmark includes data from the dump as well as the hills, then of course the measured perplexity can go up, because your model is no longer representing the dump!!!
Low perplexity is not a goal in itself. The goal is low perplexity on the highest possible quality dataset! This is why DPO can (and probably should) increase perplexity.
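A toy calculation makes the point (the per-token negative log-likelihoods below are invented): if a DPO-style shift makes the model better on the “hills” text but worse on the “dump” text, perplexity on a 50/50 mixed benchmark can rise even though perplexity on the clean text falls:

~~~python
import math

def ppl(avg_nll: float) -> float:
    # Perplexity from an average per-token negative log-likelihood.
    return math.exp(avg_nll)

# Hypothetical per-token NLLs before and after a DPO-style shift.
before = {"hills": 2.0, "dump": 2.0}
after = {"hills": 1.8, "dump": 3.0}   # better on clean text, worse on rubbish

for name, nlls in [("before", before), ("after", after)]:
    mixed = 0.5 * nlls["hills"] + 0.5 * nlls["dump"]   # 50/50 mixed benchmark
    print(name, "clean ppl:", round(ppl(nlls["hills"]), 2),
          "mixed ppl:", round(ppl(mixed), 2))
# before clean ppl: 7.39 mixed ppl: 7.39
# after  clean ppl: 6.05 mixed ppl: 11.02  -> mixed-benchmark perplexity rose
~~~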
*What makes a great language model?*
Good ideas and information tend to be more frequently replicated in language datasets. To draw a Dawkins analogy, biology replicates genes and humans replicate ideas - with a tendency towards replicating the better ideas. By “replicate”, I mean spread in written and oral form - i.e. language. By building a statistical model of written datasets, we tend to be able to generate the better ideas and information, because those are the ideas and information that humans tend to replicate/spread, and so that’s what those datasets contain more of.
Since the model itself, other than analysing statistically, doesn’t know which ideas are better, we do the following to improve results:
- Filter data to remove bad answers and messy data.
- Shift the statistical distribution away from remaining bad ideas and towards the good ideas/info. This is done through reinforcement learning (now DPO), as I described above, using pairs of good and bad answers to questions.
- Even then, with a finished model, we can further shift the model towards relevant parts of its distribution by saying things like “You are an expert on coding” OR “You are an expert on biology.”
To sum up, a great language model is one whose statistical distribution we build and then drag towards the highest quality data and ideas represented in language datasets.
A great LLM is a great compression algorithm.