PaLM: Scaling Language Modeling with Pathways

Original paper · Chowdhery et al. 2022

Google introduces PaLM, a 540-billion-parameter model and one of the largest dense language models to date. Training it required 6144 TPU v4 chips, a scale of compute available only to the top players in the field. The model posts impressive benchmark results, but it was trained at around the same time as Chinchilla and so does not follow the compute-optimal scaling recipe, which leaves it quite suboptimal. Recall that Chinchilla was a 70B model trained on 1.4T tokens, whereas PaLM is a 540B model trained on 780B tokens, making it heavily overparameterized for its training data.
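
To make the comparison concrete, here is a minimal sketch that computes tokens per parameter from the figures quoted above. The "what if" number at the end simply applies Chinchilla's ratio to PaLM's parameter count; it is an illustration, not a figure from either paper.

```python
# Tokens-per-parameter comparison using only the figures quoted in this post.
MODELS = {
    "Chinchilla": {"params": 70e9, "tokens": 1.4e12},
    "PaLM": {"params": 540e9, "tokens": 780e9},
}


def tokens_per_param(params: float, tokens: float) -> float:
    """Number of training tokens seen per model parameter."""
    return tokens / params


ratios = {name: tokens_per_param(**spec) for name, spec in MODELS.items()}

# Illustrative what-if: tokens PaLM would need to match Chinchilla's ratio.
palm_tokens_at_chinchilla_ratio = ratios["Chinchilla"] * MODELS["PaLM"]["params"]

for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f} tokens per parameter")
print(f"PaLM at Chinchilla's ratio: {palm_tokens_at_chinchilla_ratio / 1e12:.1f}T tokens")
```

Chinchilla sees about 20 tokens per parameter while PaLM sees fewer than 2, which is the sense in which PaLM is overparameterized.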

Model and Training

PaLM is a standard decoder-only Transformer trained as an autoregressive LM. The corpus consists of 780B tokens drawn from a mixture of webpages (sampled according to a quality score), Wikipedia, news articles, and code. Training runs for exactly 1 epoch. PaLM reaches a training throughput of 238.3K tokens/sec, which corresponds to a model FLOPs utilization of 46.2%, higher than any previous LLM according to the authors. There are also a few interesting architectural details to note:

  • SwiGLU activation function as opposed to the standard GeLU or ReLU.
  • Parallel formulation in the Transformer block as opposed to the standard "serialized" formulation. Because the MLP and attention branches read the same normalized input, their input matrix multiplications can be fused, resulting in roughly a 15% training speedup. Below is the standard formulation followed by the parallel one.
    $$y = x + \text{MLP}(\text{LayerNorm}(x + \text{Attention}(\text{LayerNorm}(x))))$$
    $$y = x + \text{MLP}(\text{LayerNorm}(x)) + \text{Attention}(\text{LayerNorm}(x))$$
  • Multi-query attention as covered in a previous post on this blog. Standard multi-headed attention uses k attention heads, projecting the input vector into query, key and value tensors of shape [k, h]. In MQA only the query is [k, h], while the key and value are projected as [1, h] and shared across the attention heads. This gives a significant speed-up at inference time with a neutral effect on model quality.

  • RoPE Embeddings as opposed to absolute or relative embeddings.

  • SentencePiece, which seems to be standard at Google, with a 256k vocabulary to support the multi-lingual dataset.
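
The parallel block and multi-query attention described in the list above can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions, not PaLM's implementation: the layer norm has no learned parameters, a plain ReLU MLP stands in for SwiGLU, RoPE is omitted, and all function names, weights and sizes are made up for the example.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Simplified layer norm without learned scale/bias (assumption for brevity).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def multi_query_attention(x, w_q, w_k, w_v, w_o):
    """Causal multi-query attention over a sequence x of shape [n, d].

    w_q: [d, k, h] gives one query projection per head.
    w_k, w_v: [d, h] give a single key/value projection shared by all heads.
    w_o: [k, h, d] maps the concatenated heads back to the model dimension.
    """
    n, h = x.shape[0], w_k.shape[1]
    q = np.einsum("nd,dkh->knh", x, w_q)   # per-head queries: [k, n, h]
    key = x @ w_k                           # shared keys:      [n, h]
    val = x @ w_v                           # shared values:    [n, h]
    scores = np.einsum("knh,mh->knm", q, key) / np.sqrt(h)
    causal = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(causal, scores, -np.inf)  # mask out future positions
    out = np.einsum("knm,mh->knh", softmax(scores), val)
    return np.einsum("knh,khd->nd", out, w_o)


def mlp(x, w_in, w_out):
    # ReLU MLP used as a stand-in for SwiGLU (assumption).
    return np.maximum(x @ w_in, 0.0) @ w_out


def serial_block(x, attn, ffn):
    # y = x + MLP(LayerNorm(x + Attention(LayerNorm(x))))
    return x + ffn(layer_norm(x + attn(layer_norm(x))))


def parallel_block(x, attn, ffn):
    # y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))
    normed = layer_norm(x)  # computed once and shared by both branches
    return x + ffn(normed) + attn(normed)


# Toy sizes: sequence length n, model dim d, k heads of size h.
rng = np.random.default_rng(0)
n, d, k, h = 4, 16, 2, 8
x = rng.normal(size=(n, d))
w_q = rng.normal(size=(d, k, h)) * 0.1
w_k = rng.normal(size=(d, h)) * 0.1
w_v = rng.normal(size=(d, h)) * 0.1
w_o = rng.normal(size=(k, h, d)) * 0.1
w_in = rng.normal(size=(d, 4 * d)) * 0.1
w_out = rng.normal(size=(4 * d, d)) * 0.1


def attn(t):
    return multi_query_attention(t, w_q, w_k, w_v, w_o)


def ffn(t):
    return mlp(t, w_in, w_out)


y_serial = serial_block(x, attn, ffn)
y_parallel = parallel_block(x, attn, ffn)
print(y_serial.shape, y_parallel.shape)

# Values cached per token and layer at inference time (keys plus values).
print("multi-head:", 2 * k * h, "multi-query:", 2 * h)
```

In the parallel block both branches consume the same `normed` tensor, which is what allows their input projections to be fused into one matrix multiplication. In multi-query attention the cached keys and values shrink by a factor of k, which is where the inference speed-up comes from.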

Final Thoughts

Although suboptimal in its use of compute, PaLM posts impressive scores on the evaluated benchmarks and shows that LLMs can outperform humans on a number of NLP tasks. One of the key things the authors note is the reasoning ability that emerges with chain-of-thought prompting, where the model is explicitly prompted to generate a natural-language chain of logical inference before making its prediction.
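
As a rough illustration of what chain-of-thought prompting looks like in practice, the sketch below builds a few-shot prompt with and without worked reasoning in the exemplar. The exemplar, the question and the `build_prompt` helper are all invented for this example and are not taken from the PaLM paper.

```python
# Hypothetical exemplar written for this sketch; not from the PaLM paper.
EXEMPLARS = [
    {
        "question": "A shelf holds 3 boxes with 4 books each. "
                    "5 books are removed. How many books remain?",
        "reasoning": "3 boxes of 4 books is 12 books. 12 - 5 = 7.",
        "answer": "7",
    },
]


def build_prompt(question, exemplars, chain_of_thought=True):
    """Build a few-shot prompt; optionally include the reasoning chain."""
    parts = []
    for ex in exemplars:
        if chain_of_thought:
            target = f"{ex['reasoning']} The answer is {ex['answer']}."
        else:
            target = f"The answer is {ex['answer']}."
        parts.append(f"Q: {ex['question']}\nA: {target}")
    # The final question is left open for the model to complete.
    parts.append(f"Q: {question}\nA:")
    return "\n\n".join(parts)


QUESTION = ("A train has 6 cars with 10 seats each. "
            "15 seats are taken. How many seats are free?")

prompt = build_prompt(QUESTION, EXEMPLARS)
print(prompt)
```

With `chain_of_thought=True` the exemplar shows its intermediate steps, which nudges the model to write out its own reasoning before the final answer; setting it to `False` gives the standard few-shot prompt for comparison.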