I believe the future is 1 bit models - for both training and inference.
When people make custom silicon for 1 bit models, they'll find that it is sooooo much more power and silicon-space efficient to do 1 bit math than 16 bit floating point - like 100x or more.
The extra model size you can afford will vastly overshadow any per-weight drop in model quality.
For training, how do you get any kind of meaningful derivative with it?
Maybe evolutionary algorithms instead? They haven't proven super useful historically, but maybe at the scale of enormous LLMs they will be?
Nope, they're orders of magnitude less efficient because they don't leverage gradient descent.
Rule of thumb in optimization: real numbers are easy, integers are hard
This may be the status quo because of the so-called "hardware lottery": hardware has historically been optimized for floating point. I'm speculating, but if hardware designers were instead only concerned about raw XNOR density and throughput, we might end up with chips powerful enough that giant 1-bit nets could be trained purely through evolution.
How do you optimize memory for floating point?
No, it's a fact at the mathematical level that you can enshrine in big O terms if you want to
Evolutionary algorithms made you, didn’t they?
It took a lot of human brain flops to get to this point in time though, I wonder how many orders of magnitude more than it took to train ChatGPT...
That does not prove that they can beat gradient descent.
Gradient-directed evolutionary algorithm sounds kinda interesting.
The OP explicitly excludes training.
The one I replied to said 1-bit for both training and inference.
Maybe something probabilistic?
I believe the future is 4*4-bit look-up tables (LUTs) with output latches, with a bit to/from each Cartesian neighbor. Clock them like the colors of a chessboard, in 2 phases, and you don't have to worry about timing dependencies.
All of your code gets converted to a directed acyclic graph (DAG), executing at GHz rates if you want.
Imagine a machine that can output a million parallel GPT-4 streams at 1000 tokens per second each.
If the architecture changes, it's just a different DAG. Unlike with FPGAs and their custom blocks that have to be used optimally, you can compile a DAG almost instantly.
Is this something from a research finding or is it your idea?
It's my idea... I've been writing about it for over a decade. I hope to get a small version of it made via Tiny Tapeout later this year, once I'm out of the precariat.
The idea is that you can decompose any non-cyclic code (for example, a matrix multiply, FFT, etc.) into a directed graph of bitwise operations, then map those operations out into the grid of LUTs. Pipe inputs in one side of the grid, get the outputs out the other side, and all of the bitwise operations happen in lock-step parallelism. (Unlike an FPGA, which is asynchronous.)
If you can decompose one stage of an LLM down into a bitwise graph of computation, you can easily map it into the grid. If there's a bad cell in the grid, you can map around it.
Because all of the high-speed logic lines are short (just to the neighbors), you don't have the long-line / high-capacitance issues of driving signals all the way across an FPGA or ASIC, so you can really crank up the clock rate (and/or it's very power efficient).
It should be trivial, design-wise, to scale from a 16x16 grid up through chips that could crank out petaflops.
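To make the lock-step, two-phase idea concrete, here's a rough Python sketch of the update rule (the grid size, random LUT contents, and wrap-around edges are placeholders of mine, not part of the actual design):

    import random

    W, H = 4, 4  # toy grid size; a real chip would be far larger

    # Each cell holds 4 look-up tables (one per output direction), each a
    # 16-entry truth table over its 4 neighbor inputs (N, E, S, W).
    luts = [[[[random.randint(0, 1) for _ in range(16)] for _ in range(4)]
             for _ in range(W)]
            for _ in range(H)]

    # Latched outputs per cell: [to_N, to_E, to_S, to_W]
    out = [[[0, 0, 0, 0] for _ in range(W)] for _ in range(H)]

    def neighbor_inputs(y, x):
        """Gather the bit each Cartesian neighbor drives toward cell (y, x)."""
        n = out[(y - 1) % H][x][2]  # north neighbor's 'to_S' bit
        e = out[y][(x + 1) % W][3]  # east  neighbor's 'to_W' bit
        s = out[(y + 1) % H][x][0]  # south neighbor's 'to_N' bit
        w = out[y][(x - 1) % W][1]  # west  neighbor's 'to_E' bit
        return (n << 3) | (e << 2) | (s << 1) | w

    def half_step(phase):
        """Update only the cells whose 'chessboard color' matches this phase.

        Every neighbor of an updating cell is the opposite color, so nothing
        read during this half-step is being written during it -- that's what
        removes the timing dependencies."""
        for y in range(H):
            for x in range(W):
                if (x + y) % 2 != phase:
                    continue
                addr = neighbor_inputs(y, x)
                out[y][x] = [luts[y][x][d][addr] for d in range(4)]

    def clock_cycle():
        half_step(0)  # "black" squares latch new outputs...
        half_step(1)  # ...then "white" squares, using the freshly latched bits

    for _ in range(8):
        clock_cycle()
    print(out)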
Here's a simulator for the thing, written in Pascal.
https://github.com/mikewarot/Bitgrid
You can simulate this on a GPU now. It's like running a cellular automaton (CA) on a GPU with slightly different rules.
I think you would also be delighted to know that many of the non-GPU AI accelerators use a mesh topology.
Cerebras, Tenstorrent, Esperanto
https://www.esperanto.ai/wp-content/uploads/2021/08/HC2021.E...
I've been intrigued by this approach. A more highly optimized (but harder to program) take is the GA144 [1] from Chuck Moore, the inventor of Forth. It's a grid of 144 F18 Forth-based processors in a Cartesian grid. These processors are far more limited, but then again they take far less power as well.
[1] https://www.greenarraychips.com/
1. If you write FPGA code as a grid of lookup tables then I would expect it to be easy to compile instantly.
2. In what way is this "acyclic"?
3. Won't putting your code into this form be the hard part? Even if you start with a DAG, 99.99% of them won't fit this form without intense restructuring. So you just moved the hard step over by one.
Isn't 1 bit too low for optimal radix economy (Euler's number), though?
I want to see somebody try "imbalanced quaternary": -1, 0, +1, 2
Haven't heard this argument before. But from the Wikipedia article it seems base 3 has the best asymptotic radix economy, though it isn't much better than base 2, and base 2 is seemingly easier to program and optimize.
Since this is a new argument I've not heard, would be curious if you had links or some more explanation.
People are indeed working on {-1, 0, 1, 2} Q2 models; I read something about it the other day but don't recall the title.
They mentioned a decomposition of Q2 into ternary + binary, i.e. [[1,2],[-1,0]] -> [[1,1],[0,0]] + [[0,1],[-1,0]]
Radix economy is all about which base is the most efficient for representing a given number. It is simple to show that, for large N, the cost scales as b*log_b(N) = (b/ln b)*ln(N), so what matters is the base-dependent factor b/ln(b). Simple calculus shows this is minimized at e (Euler's number), or at 3 if you restrict to integers (closely followed by 2).
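To put numbers on it (standard textbook math, nothing specific to this thread), here's a quick check of the b/ln(b) factor for a few bases:

    import math

    # Representing a large N in base b takes ~log_b(N) digits, each with b
    # possible values, so the cost is ~ b * log_b(N) = (b / ln b) * ln N.
    # The base-dependent factor b / ln b:
    for b in (2, 3, 4, 10):
        print(b, b / math.log(b))
    # 2 -> 2.885, 3 -> 2.731 (best integer), 4 -> 2.885, 10 -> 4.343
    # the continuous minimum is at b = e ~= 2.718

Base 3 edges out base 2, but only by about 5%.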
It sounds like you have something to add, but you are already dictating the base by saying "bit" - literally "binary digit". Anyway, quantization is not about which number system is best; virtually all computer systems we use today represent numbers in base 2. Quantization, at its core, is lossy compression: how do you go from a large model trained at high precision to a smaller model without hindering performance? This can be studied without needing to know the base.
Suppose you are using a decimal computer. You can ask yourself: I have 128-decimal-precision numbers; do I need that much precision? What happens if I just use 1-decimal precision by chopping off the 127 digits after the first?

You realize that there are two parts to an operation: the numbers involved (the operands) and the operation itself. You then ask yourself, if I keep one of the operands fixed (the original input), can I represent my 128-decimal-precision neural network simply as a series of operations without the other operand? Perhaps only the most basic ones, like noops (add 0 or multiply by 1), increments (add 1), decrements (subtract 1), negations (multiply by -1), and clears (multiply by 0)? You count those numbers (-1, 0, and 1). There are 3, so you proudly proclaim you've made a neural network that only uses 0.477 dits. People get excited and confused because that is less than 1 dit, which seems like a fundamental point. You further surprise the scientific field by finding a clever trick for getting rid of negations. You beat your previous record and now you only need 0.301 dits to represent your network.

You are about to accept your Turing award when the ghost of Claude Shannon appears and says "Why are you using a unit that measures entropy to mean how many symbols you have? If you insist, at least realize 0.301 dits is 1 bit." You are shocked when you realize 10^0.301 = 2^1. Reviewing Shannon's seminal paper [1], you are awestruck by Shannon's prescient comment "Change from the base a to base b merely requires multiplication by log_b(a)." You humbly give your award to Shannon. You keep the $1M since ghosts aren't as fast as a new NVIDIA DGX. No matter how quantized the ghost is.
[1] - https://people.math.harvard.edu/~ctm/home/text/others/shanno...
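For anyone who wants the unit conversion in the parable spelled out, it's just Shannon's change-of-base rule:

    import math

    # three symbols {-1, 0, 1}, measured in decimal digits (dits) and in bits
    print(math.log10(3))   # ~0.477 dits
    print(math.log2(3))    # ~1.585 bits
    # after the trick that removes negations: two symbols {0, 1}
    print(math.log10(2))   # ~0.301 dits
    print(math.log2(2))    # exactly 1 bit
    # change from base a to base b: multiply by log_b(a)
    print(0.301 * math.log2(10))   # ~1.0, i.e. 0.301 dits == 1 bit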
I bet the optimal "large" value is bigger than 2.
Probably more than 100x for inference. Not only are you drastically reducing the number of bits and replacing float math with integer math, you can do matrix multiplication with only addition (as pointed out in the BitNet b1.58 paper). Additions require a lot less hardware to implement than multiplications. Adding one-bit or two-bit numbers requires barely any hardware at all: a traditional two-bit adder without a carry-out is three XOR gates and an AND gate.
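As a toy illustration of the add-only matmul (a sketch of the general idea, not the actual BitNet kernel):

    import numpy as np

    def ternary_matvec(W, x):
        """Matrix-vector product with weights restricted to {-1, 0, +1}.

        No multiplications: each weight only adds, subtracts, or skips the
        corresponding activation."""
        y = np.zeros(W.shape[0], dtype=x.dtype)
        for i in range(W.shape[0]):
            acc = 0.0
            for j in range(W.shape[1]):
                if W[i, j] == 1:
                    acc += x[j]
                elif W[i, j] == -1:
                    acc -= x[j]
                # W[i, j] == 0: skip entirely
            y[i] = acc
        return y

    W = np.array([[1, 0, -1], [-1, 1, 1]])
    x = np.array([0.5, -2.0, 3.0])
    print(ternary_matvec(W, x))   # [-2.5  0.5]
    print(W @ x)                  # same result via an ordinary matmul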
To me the most exciting thing is that if it's training that gets sped up on the order of 100x-1000x, a large cluster may be well suited to gradient-descending over hyperparameters by running LLM training again and again at scale -- this is the first foot in the door towards an AI that can iteratively improve itself.
LoRA training should benefit from the same speed-up, because the 1-bit weights will be frozen and all you need for both the forward and backward pass is a binary matmul, then maybe a cast afterwards to get more stable gradients.
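Roughly what that could look like (a hypothetical PyTorch sketch; the sign-binarized base weights and the shapes are my own assumptions, not from any particular paper):

    import torch

    d, r = 1024, 16

    # Frozen {-1, +1} base weights plus a trainable low-rank (LoRA) adapter.
    W_bin = (torch.randint(0, 2, (d, d)) * 2 - 1).float()
    A = torch.nn.Parameter(0.01 * torch.randn(r, d))
    B = torch.nn.Parameter(torch.zeros(d, r))

    def forward(x):
        # The frozen binary matmul (on custom hardware this could be
        # XNOR + popcount), accumulated in float here, plus the
        # higher-precision LoRA path.
        base = x @ W_bin.t()
        lora = (x @ A.t()) @ B.t()
        return base + lora

    x = torch.randn(8, d)
    forward(x).sum().backward()  # gradients flow only into A and B; W_bin stays fixed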
Doesn't training need higher precision to avoid getting stuck at local minima, at least with backpropagation-style learning?
Maybe something alternate like evolutionary algorithms could work in a domain like this, but so far those have proven to be less efficient.
A recent paper used a single ternary "trit" (~1.58 bits) per parameter for training: https://news.ycombinator.com/item?id=39535800 They said it would be difficult to apply this technique to pre-trained models; it had to be trained with trits from scratch.
But it's using larger weights during training, not just the trits.
...What?
I think OP was referring to parameter size. You can make up for quantization by having more parameters.
I know. I was being a bit mean, but they seem to be implying that more parameters even with worse performance is preferable. That seems backward to me. There is no reason for parameter size to be intrinsically valuable, and I think people are a bit infatuated with it.
The ternary-weight models only compress the weights. You still add and subtract using bfloat16 for better accuracy. Dedicated silicon is mostly a waste, because you are only performing two operations per parameter during inference. Loading the parameters from slow DDR, GDDR, or HBM memory is the bottleneck in practice, and the only solution is PIM (processing-in-memory). I was honestly disappointed by Nvidia's Blackwell since it is just barely competitive with GDDR PIM.
At least you can copy 16 times more data to the shared memory with binary weights.
1 bit's nothin'. The future is training directly on electronic/photonic circuits.
This! Maybe just integer, but not floating point. That's a ridiculous way to do computation when you don't really need the precision.