This looks great.
The authors randomly permute (i.e., shuffle) input tokens in training and add two positional encodings to each token: one with the token's position and another with the position of the token to be predicted. Otherwise, the model is a standard autoregressive GPT. The consequences of this seemingly "simple" modification are significant:
* The authors can prompt the trained model with part of a sequence and then decode the missing tokens, all at once, in parallel, regardless of order -- i.e., the model can in-fill in parallel.
* The authors can compute conditional probability densities for every missing token in a sequence, again in parallel, i.e., densities for all missing tokens at once.
* The authors propose a rejection-sampling method for generating in-fill tokens, again in parallel. Their method seems to work well in practice.
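To make the double positional encoding concrete, here's roughly how I imagine the training-time input construction (my own sketch, not the authors' code; `tok_emb` and `pos_emb` are ordinary embedding tables):

```python
import torch

def build_training_example(tokens, tok_emb, pos_emb):
    # tokens: LongTensor of shape (seq_len,)
    seq_len = tokens.shape[0]
    perm = torch.randperm(seq_len)        # a random generation order for this sample
    shuffled = tokens[perm]

    cur_pos = perm[:-1]                   # each input token's own position
    tgt_pos = perm[1:]                    # the position of the token it must predict
    inputs = tok_emb(shuffled[:-1]) + pos_emb(cur_pos) + pos_emb(tgt_pos)
    targets = shuffled[1:]                # the next token in the shuffled order
    return inputs, targets                # feed `inputs` to a standard causal GPT
```

Everything downstream of that is, as far as I can tell, a vanilla causal transformer trained with cross-entropy on `targets`.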
I've added this to my reading list. Thank you for sharing it on HN.
I don't understand how that parallel prediction can work...
Let's say I give it as input the sentence:
I . . . . . . . . happily.
The second word to be predicted depends on the first word.
Give the model the tokens "happily" and "I", and add to each input token its respective position embedding and the position embedding for the token to be predicted. You can do this in parallel for all tokens to be predicted. The model has been trained so it can predict tokens in any position.
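Here's how I picture the parallel query at inference, as a rough sketch (hypothetical `model`, `tok_emb`, `pos_emb`; I'm batching one copy of the prefix per missing position, whereas the paper may do this more cleverly within a single sequence):

```python
import torch

def missing_token_logits(model, tok_emb, pos_emb, obs_tokens, obs_pos, missing_pos):
    # obs_tokens, obs_pos: (n_obs,) observed tokens ("I", "happily") and their positions
    # missing_pos: (n_miss,) positions we want distributions for
    n_miss = missing_pos.shape[0]

    # prefix: each observed token carries its own position plus the position of the
    # next observed token (the last slot is overwritten per query below)
    prefix = tok_emb(obs_tokens) + pos_emb(obs_pos) + pos_emb(obs_pos.roll(-1))

    # one copy of the prefix per missing position; only the last token's
    # "position to predict" differs between copies
    x = prefix.unsqueeze(0).repeat(n_miss, 1, 1)
    x[:, -1] = tok_emb(obs_tokens[-1]) + pos_emb(obs_pos[-1]) + pos_emb(missing_pos)

    logits = model(inputs_embeds=x)   # hypothetical causal GPT that accepts embeddings directly
    return logits[:, -1]              # (n_miss, vocab): one distribution per gap
```

All of those distributions come out of a single forward pass, which is the "in parallel" part.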
Yes, but is there any guarantee that the complete sentence makes sense?
That guarantee didn't exist with regular GPT LLMs, did it? It just came about as an emergent property of throwing more and more compute, training data, and training time at the problem.
I think it's effectively built into the design. The model outputs a probability distribution for the first unknown token [0]. Then some code outside the model chooses a token and runs the model again with that token included in the input. So the second output token's probability distribution is automatically conditioned on the first output token, etc.
Sometimes people will attempt to parallelize this by using a faster model to guess a few tokens and then evaluating them as a batch with the main model to determine whether the guesses were good (speculative decoding).
[0] Usually it outputs “logits”, which become a probability distribution when combined with a “temperature” parameter.
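In code, that sequential loop is basically this (a bare-bones sketch with a hypothetical `model` mapping token ids to next-token logits; this loop is what speculative decoding tries to amortize):

```python
import torch

def sample(model, prompt_ids, max_new_tokens, temperature=1.0):
    ids = prompt_ids.clone()                            # (seq_len,)
    for _ in range(max_new_tokens):
        logits = model(ids.unsqueeze(0))[0, -1]         # logits for the next position
        probs = torch.softmax(logits / temperature, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id])                 # later steps are conditioned on this choice
    return ids
```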
It isn't. There is no guarantee that successive tokens will be comprehensible.
The logits are the probability distribution (well technically, you would apply softmax). Temperature is a parameter for how you sample those logits in a non-greedy fashion.
I think temperature is better understood as a pre-softmax pass over logits. You'd divide logits by the temp, and then their softmax becomes more/less peaky.
Sampling is a whole different thing.
Sure, my comment about softmax was simply about the probability distribution. But temperature is still part of sampling. If you're greedy decoding, temperature doesn't matter.
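A tiny illustration of both points (nothing to do with the paper): dividing the logits by the temperature before softmax changes how peaky the distribution is, while greedy argmax picks the same token regardless.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for temp in (0.5, 1.0, 2.0):
    print(temp, torch.softmax(logits / temp, dim=-1))   # lower temp -> peakier, higher -> flatter

print(torch.argmax(logits))   # greedy pick is the same at any temperature
```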
No, but it makes more conceptual sense, given that the model can consider what was said before it.
That is indeed an issue. Their sampling method rejects impossible combinations.
Isn't this bag of words all over again? Except with positional hints?
Off topic, but what do you use for your reading list?
I use Emergent Mind[1] to keep track of new research published on ArXiv. You can bookmark articles once logged in. It's very useful for keeping track of articles, reading quick summaries, and following conversations on various social media.
[1]: https://www.emergentmind.com/papers/2404.09562
Hijacking for a bit of shameless self-promotion: if you're an Obsidian user, I recently built a plugin that simplifies web pages, parses out metadata, and saves them to Obsidian as markdown files: https://github.com/inhumantsar/slurp
arXiv pages come through a bit ugly at the moment, but fixing that is high on my to-do list. I'm leveraging the library that Firefox uses for reader mode, so most sites come through quite well. A lot of my work right now is expanding its metadata support and fixing parser issues.
old-fashioned text files
Google Chrome has a built-in reading list (open the three-dot menu at the top-right corner, then click "Bookmarks and lists" -> "Reading list").
Zotero is great for organizing and annotating papers, keeping notes, and building bibliographies.
You can create libraries and sub-libraries according to topic, and also create libraries for projects or reading lists. You can file items into multiple libraries, and you can also create shared libraries, allowing your team to share annotated papers.
Finally, it can archive offline copies of web pages, which makes it useful for blog articles and other online resources that might vanish.
There's a learning curve, but it's worth it if you find yourself juggling dozens or hundreds of technical papers! Enjoy!
The only difference I see from XLNet is how they use it during inference.
Hey! I'm Arnaud, first author of the paper. XLNet also shuffles the data during training, but it uses a masking mechanism instead of the causal setup with double positional encoding. The application also differs: XLNet is not, AFAIK, focused on generation (even though it can be used for that), and the burst-sampling idea is new.
Are there any obvious practical applications of this algorithm for existing large (10B+) text/image models?
Does the rejection sampling lead to a statistically correct sample from the joint probability distribution or is that just a (possibly rough) approximation?
For the application: being able to prompt anywhere in the sequence can be of interest. From what we've seen in the experiments, the rejection sampling leads to generations similar to the autoregressive ones; we did not see any mode collapse or anything of that kind.
Thanks for the clarification!
I know this is for tokens/text, but can the same concept be applied to images using something like a diffusion model? And could it then upscale images arbitrarily by infilling?
Yes. See the related-work section in the paper: there is a long history of models, most recently MAE and MaskGIT, which predict pixels in basically arbitrary orders, and that is useful because it lets you train on subsets of each image, upscale/infill during generation, and so on. (If you know what MAE is, that might be the fastest way to summarize OP: "it's a GPT trained like a MAE".)
People also often forget "orderless autoregression", which was introduced a while back and has been reinvented many times since. See Sec 4 (pg 8) of "Neural Autoregressive Distribution Estimation" [https://arxiv.org/abs/1605.02226]. The main difference from current work is that this 2016 paper used MLPs and convnets on fixed-length observations/sequences, so sequence position is matched one-to-one with position in the network's output, rather than conditioning on a position embedding. Of course, Transformers make this type of orderless autoregression more practical for a variety of reasons -- TFs are great!
Key quote from Sec 4: "In this section we describe an order-agnostic training procedure, DeepNADE (Uria et al., 2014), which will address both of the issues above. This procedure trains a single deep neural network that can assign a conditional distribution to any variable given any subset of the others. This network can then provide the conditionals in Equation 1 for any ordering of the input observations. Therefore, the network defines a factorial number of different models with shared parameters, one for each of the D! orderings of the inputs. At test time, given an inference task, the most convenient ordering of variables can be used."
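For anyone who wants the gist of that procedure in code, it boils down to something like this (my own simplification with a hypothetical `net`; the real DeepNADE setup differs in the details):

```python
import torch
import torch.nn.functional as F

def orderless_training_step(net, x, optimizer):
    # x: (batch, D) discrete variables; net(masked, mask) -> (batch, D, vocab) logits
    batch, D = x.shape
    keep_ratio = torch.rand(batch, 1)                  # size of the conditioning set
    mask = (torch.rand(batch, D) < keep_ratio).long()  # 1 = observed, 0 = to predict
    masked = x * mask                                  # hide the "future" of a random ordering
    logits = net(masked, mask)                         # the mask tells net which zeros are real

    hidden = mask == 0
    loss = F.cross_entropy(logits[hidden], x[hidden])  # score only the hidden variables
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    return loss.item()
```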
If there are multiple missing tokens, what's the positional encoding for the "token to be predicted"?
See this thread, also on this page:
https://news.ycombinator.com/item?id=40609689
This problem formulation has been around for a while; it's kind of the holy grail of modeling. What is new compared to PixelCNN and related work is the position-embedding idea.
Wow, if that works, that's wild (and it also has that "damn, now that you say it, it's obvious" flavour that so many really cool discoveries share...)
Wait, wasn't BERT all about non-causal masking, aka predicting words in the middle?!