Llama 2: Open Foundation and Fine-Tuned Chat Models
Original paper · Touvron et al. 2023
Meta keeps on giving to the open-source community, this time with a natural successor to the Llama 1 family and a set of open models that can viably compete with dialogue-style LLMs such as ChatGPT, Bard and Claude. Llama 2 is a family of pretrained LLMs of similar sizes to Llama 1, trained with an analogous recipe. Llama 2-Chat is a completely new model family consisting of instruction fine-tuned versions of Llama 2, optimized for dialogue use cases. I've been excited about this paper for a long time, so I'm glad we're finally here in this research journey.
This paper is extensive, and Meta is very generous with the detail they provide on both pretraining and the subsequent fine-tuning.
Pretrained Model Family
As mentioned, Llama 2 is a family of pretrained models created using a slightly updated Llama 1 approach. The training data has been extended with new publicly available sources (excluding data from Meta's products and services) and now totals 2 trillion tokens (up from 1.0-1.4T for Llama 1). The context window is expanded from 2048 to 4096 tokens, which matched ChatGPT's context window before its upgrade to 8k. Architecturally, the only unfamiliar module is Grouped-Query Attention (GQA), used in the 34B and 70B models. It is a key and value projection scheme reminiscent of Multi-Query Attention (MQA), which we've covered on this blog before. What's great is that they've actually performed an ablation study comparing standard multi-head attention (MHA) with both MQA and GQA. They find that GQA performs closest to the MHA baseline and that it scales better than MQA. As a reminder, MQA and GQA share the key/value projections across multiple heads to offset the memory cost of a growing KV cache.
Compared to open-source base models, Llama 2 reports strong performance on a suite of popular benchmarks across Code, Commonsense Reasoning, World Knowledge, Reading Comprehension and Math. Llama 2 70B is close to GPT-3.5 on MMLU and GSM8K but lags significantly behind on coding benchmarks. Its results are on par with or better than PaLM.
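To make the KV cache argument concrete, here is a rough back-of-the-envelope sketch. The layer count, head count, head dimension and fp16 storage below are assumptions picked for illustration, and the helper names are hypothetical rather than taken from the paper. The only thing that changes between MHA, GQA and MQA is how many key/value heads have to be cached.

```python
def kv_cache_bytes(n_kv_heads, n_layers=80, head_dim=128, seq_len=4096,
                   batch_size=1, bytes_per_value=2):
    """Bytes needed to cache keys and values for one batch of sequences."""
    # The leading 2 accounts for storing both a key and a value tensor.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value


def kv_head_for_query_head(query_head, n_query_heads, n_kv_heads):
    """Index of the shared key/value head that a given query head uses."""
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size


N_QUERY_HEADS = 64  # assumed for illustration

# MHA keeps one KV head per query head, MQA keeps a single one,
# and GQA sits in between with a small number of groups.
for name, n_kv in [("MHA", N_QUERY_HEADS), ("GQA", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(n_kv) / 2**30
    print(f"{name}: {n_kv:2d} KV heads, {gib:.2f} GiB of KV cache per sequence")
```

Under these assumptions, going from 64 to 8 key/value heads shrinks the cache by a factor of 8, while each group of 8 query heads still gets its own key/value head instead of all 64 sharing one as in MQA.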
Fine-tuning Model Family
Llama 2-Chat is a dialogue-style LLM that is supposed to be easy and safe to chat with. The fine-tuning process involves both instruction tuning and RLHF, which requires a lot of resources. The process is very similar to what we saw in InstructGPT: Llama 2-Chat is bootstrapped with a supervised fine-tuning stage and then iteratively improved using RLHF to align with human preferences.
SFT
SFT is bootstrapped with publicly available instruction-tuning data and then continued on a much smaller set of high-quality annotations collected by Meta. They found that a limited set of clean instruction-tuning data (tens of thousands of examples) was sufficient to achieve high quality, and therefore set aside millions of available third-party examples. Each sample consists of a prompt and an answer, and the samples are concatenated into long sequences, with prompt and answer separated by a special token. The model is then fine-tuned in the same way as during pretraining, with an autoregressive objective, except that the loss on the user prompt tokens is zeroed out.
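The zeroed-out prompt loss can be sketched in a few lines. This is a minimal illustration rather than Meta's training code: the function name, the per-token log-probabilities and the boolean mask are all made up for the example, and how the separator token itself is treated is left out.

```python
import math


def masked_nll(token_log_probs, is_answer_token):
    """Mean negative log-likelihood computed over answer tokens only."""
    total = 0.0
    count = 0
    for log_prob, keep in zip(token_log_probs, is_answer_token):
        if keep:  # prompt tokens are skipped, so they contribute no loss
            total -= log_prob
            count += 1
    return total / count if count else 0.0


# Hypothetical log-probabilities the model assigns to each target token:
# two prompt tokens followed by two answer tokens.
log_probs = [math.log(0.1), math.log(0.2), math.log(0.5), math.log(0.25)]
answer_mask = [False, False, True, True]

loss = masked_nll(log_probs, answer_mask)
print(f"loss over answer tokens: {loss:.4f}")
```

Because the prompt positions are masked, changing how well the model predicts the prompt has no effect on the loss. Only the answer tokens drive the update.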
RLHF
Data was collected from human annotators, similarly to InstructGPT. Annotators were asked to compare two model responses, on both helpfulness and safety. This data was then used to train two separate reward models, one optimized for helpfulness and another for safety. Naturally, both of these models are initialized from chat model checkpoints to avoid any mismatch in knowledge representation. The reward models are used to iteratively improve the Llama 2-Chat models: as more batches of human preference data came in, a better reward model could be trained, which in turn was used to train a better RLHF model. Where InstructGPT relies on PPO alone, Llama 2-Chat also uses Rejection Sampling (RS) fine-tuning. The earlier RLHF iterations use RS only, and later ones apply PPO on top of the RS checkpoint. RS was new to me, but the intuition behind it is sound. For each prompt in the fine-tuning dataset, several outputs are sampled and scored by the reward model. The best-scoring output is considered the new gold standard and is used to fine-tune the model. The benefits of Rejection Sampling are apparent when you look at the delta between the maximum and median reward curves among samples.
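The sampling-and-selection step of Rejection Sampling can be sketched as follows. This is a toy illustration under stated assumptions: `toy_generate` and `toy_reward` are hypothetical stand-ins for the chat model and the reward model (the toy reward simply prefers longer answers), and the subsequent fine-tuning on the selected outputs is omitted.

```python
import itertools
import statistics


def rejection_sample(prompt, generate, reward, k):
    """Draw k candidate responses and return the best one plus all scores."""
    candidates = [generate(prompt) for _ in range(k)]
    scores = [reward(prompt, candidate) for candidate in candidates]
    best_index = max(range(k), key=lambda i: scores[i])
    return candidates[best_index], scores


# Hypothetical stand-in for sampling from the current chat model.
_canned = itertools.cycle(
    ["ok", "a longer answer", "the most detailed answer here", "hm"]
)


def toy_generate(prompt):
    return next(_canned)


# Hypothetical stand-in for the reward model: longer responses score higher.
def toy_reward(prompt, response):
    return float(len(response))


best, scores = rejection_sample("Explain GQA", toy_generate, toy_reward, k=4)
gap = max(scores) - statistics.median(scores)
print(f"selected: {best!r}")
print(f"max - median reward among samples: {gap}")
```

The gap between the maximum and the median score is what the selected sample gains over a typical one. Fine-tuning on the selected outputs is how that gain is fed back into the model.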
Learning and Observations
While the training recipes are interesting, the intuitions and insights of the researchers working on the project are almost more intriguing.
- Early on, many of the project members expressed a preference for supervised annotation. This seems natural given the dense signal it provides, but despite the initial doubts, RLHF proved highly effective in both cost and time. The team underscores that the crucial determinant of RLHF's success lies in the synergy it fosters between humans and LLMs.
- Even with proficient annotators, a model fine-tuned on SFT annotations learns the full diversity of that data, including the tail end of poorly executed annotations. Additionally, model performance is capped by the writing abilities of the most skilled annotators. During RLHF, the reward model quickly learns to score the undesirable tail end low, so the model moves away from it and can even move into areas where the best annotators may not go. Intuitively, not all humans can be accomplished chefs, but our ability to appreciate and critique food remains intact. The team posits that the superior writing abilities of LLMs are fundamentally driven by RLHF, and that supervised data may no longer be the gold standard.
- The models appear to have developed a sense of time, despite being trained solely on next-token prediction. They demonstrate a robust capability to organize their knowledge in a temporal manner.