XLNet: Generalized Autoregressive Pretraining for Language Understanding

Original paper · Yang et al., 2019

This paper is a natural extension of the bidirectional modeling proposed by BERT in 2018. BERT masks part of the input and trains a transformer to predict the masked tokens from context on both sides, in contrast to the autoregressive (AR) model GPT. This bidirectional modelling is crucial for many natural language understanding tasks. However, the approach neglects the dependencies between the masked tokens and may therefore hurt downstream performance. In addition, real data used in fine-tuning or prediction contains no corrupted input, which introduces a pretrain-finetune discrepancy. Autoregressive language models, on the other hand, estimate the probability of the text corpus by factorizing the joint distribution with the product rule. BERT cannot model this joint probability, since it implicitly assumes that the predicted tokens are independent of each other given the unmasked tokens. Given that AR models and BERT each have advantages over the other, the authors introduce XLNet, a generalized autoregressive pretraining method that seeks to combine the strengths of both approaches.
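
Roughly, following the notation in the paper, the two pretraining objectives can be written as (here x̂ is the corrupted input and m_t indicates whether x_t is masked):

$$\max_\theta \; \log p_\theta(\mathbf{x}) = \sum_{t=1}^{T} \log p_\theta(x_t \mid \mathbf{x}_{<t}) \qquad \text{(AR)}$$

$$\max_\theta \; \log p_\theta(\bar{\mathbf{x}} \mid \hat{\mathbf{x}}) \approx \sum_{t=1}^{T} m_t \log p_\theta(x_t \mid \hat{\mathbf{x}}) \qquad \text{(BERT)}$$

The ≈ in BERT's objective is exactly the independence assumption mentioned above: each masked token is predicted separately given the corrupted input.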

Method

First, the authors address the issue of capturing bidirectional context within an autoregressive model. To do this they borrow ideas from NADE, a model I've yet to cover on this blog, and propose a permutation language modelling objective. For a sequence x of length T, there are T! different orders in which a valid autoregressive factorization can be performed. Importantly, the actual sequence order remains unchanged: the model keeps the positional encodings of the original sequence and relies on attention masks to realize the permuted factorization order, as sketched below. Although this achieves the desired properties, it introduces a number of caveats. If you are interested, they are covered in detail in the paper; I found them quite difficult to follow so I am going to skip over them.
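
As a rough illustration of how a sampled factorization order becomes an attention mask, here is a minimal sketch (my own simplification, not the paper's two-stream implementation; the helper name `permutation_mask` is made up, and the content-stream vs. query-stream distinction of whether a position may attend to itself is ignored):

```python
import numpy as np

def permutation_mask(perm):
    """Build a T x T attention mask for a sampled factorization order.

    perm[k] is the original position predicted at step k of the permuted
    autoregressive factorization. mask[i, j] == 1 means position i may
    attend to position j. Token order and positional encodings stay those
    of the original sequence; only the mask changes which context each
    prediction sees.
    """
    T = len(perm)
    # rank[i] = step at which original position i is predicted
    rank = np.empty(T, dtype=int)
    rank[perm] = np.arange(T)
    mask = np.zeros((T, T), dtype=int)
    for i in range(T):
        for j in range(T):
            # position i may only attend to positions that come earlier
            # in the factorization order
            if rank[j] < rank[i]:
                mask[i, j] = 1
    return mask

# Example: length-4 sequence, factorization order 3 -> 2 -> 4 -> 1
# (0-indexed: [2, 1, 3, 0]); the position predicted first sees nothing,
# the one predicted last sees all other positions.
print(permutation_mask(np.array([2, 1, 3, 0])))
```

Averaging the AR objective over sampled permutations is what lets each position, in expectation, see context from both sides while the model remains autoregressive.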

Since the objective function of XLNet fits into the AR framework, the authors incorporate Transformer-XL into their pre-training framework and name the method after it. Transformer-XL introduces two important techniques that improve on the original Transformer architecture, both of which are adopted in XLNet: (1) relative positional encodings and (2) segment-level recurrence. Check out my blog post for an overview.
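
To give a sense of what segment-level recurrence means in practice, here is a minimal single-head sketch (my own simplification, assuming a cached memory of the previous segment; it omits relative positional encodings, multi-head attention and the per-layer memories of the real model): the previous segment's hidden states are detached from the gradient and prepended to the keys and values of the current segment.

```python
import torch

def attend_with_memory(h, mem, w_q, w_k, w_v):
    """Single-head self-attention over the current segment plus a cached
    memory of the previous segment's hidden states.

    h:   (seg_len, d) hidden states of the current segment
    mem: (mem_len, d) cached states from the previous segment
    """
    context = torch.cat([mem, h], dim=0)            # keys/values span both segments
    q = h @ w_q                                     # queries only for the current segment
    k = context @ w_k
    v = context @ w_v
    attn = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
    return attn @ v

d, seg_len, mem_len = 16, 8, 8
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))

mem = torch.zeros(mem_len, d)                       # no memory for the first segment
for step in range(3):                               # process a stream of segments
    h = torch.randn(seg_len, d)
    out = attend_with_memory(h, mem, w_q, w_k, w_v)
    # cache this segment's states for the next one; stop gradients so the
    # recurrence does not backpropagate through previous segments
    mem = h.detach()
```

The cached memory is what lets the model use context beyond the current segment without recomputing it, which matters for the long sequences XLNet is pretrained on.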

Comparison to BERT

For a fair comparison, the authors train a comparison model, XLNet-Large-wikibooks, on BooksCorpus and English Wikipedia only, reusing the pretraining hyperparameters of the original BERT. Pre-training is then scaled up for XLNet-Large by adding Giga5, ClueWeb 2012-B and Common Crawl, for a total of 32.89B subword pieces. The models are trained with the Adam weight decay optimizer, linear learning rate decay and a batch size of 8192.

To decouple the effect of using more data, BERT is first compared against XLNet-Large-wikibooks. Trained on the same data with an almost identical training recipe (model and data hyperparameters), XLNet outperforms BERT on most of the considered datasets. XLNet-Large is then compared to BERT and other pretrained models such as RoBERTa on tasks including reading comprehension, question answering, text classification and natural language understanding. On all but a handful of these experiments, XLNet outperforms both BERT and RoBERTa, which points to the strengths of the bidirectional AR approach.

Final thoughts

XLNet is an impressive model: it realizes a generalized AR pretraining method with a permutation language modeling objective that combines the benefits of AR and AE approaches. It also incorporates key techniques from the state-of-the-art Transformer-XL to achieve substantial improvements over previous pretraining objectives on a variety of tasks.