Jamba boasts an extensive context window of 256K tokens, equivalent to around 210 pages of text, while fitting up to 140K tokens on a single 80GB GPU.
I realize this is a big improvement, but it's striking how inefficient LLMs are when you need 80 GB of GPU memory to analyze less than a megabyte of data. That's a lot of bloat! Hopefully there's a lot of room for algorithmic improvements.
It’s literally simulating a neural network.
How much of your 5-sense experiential memories and decades of academic book learning are you bringing to understand my reply to your post?
How many gigabytes do you think that’s equivalent to?
It's kinda simulating our brains, but not really. When I tried to dig into how neurons actually work, I realised there's a massive chasm of difference. Very much worth doing if you haven't (you might know far better than me; this is for people who don't yet).
In terms of results: our brains work with 20 W of power and can be trained to compete with LLMs using a tiny fraction of the world's data. They also have to keep you breathing and your blood pumping and manage all the dangers of catching a ball near traffic. Or skiing, or poetry, or sunsets. And they remember stuff five minutes later and don't need a training run that takes months.
We have SO many opportunities to improve the AI architecture it’s ridiculous. This is a good thing.
To be fair most of the brain is more like a pretrained model — it isn't being trained at any point after conception to keep your blood pumping or your lungs working, it does that out of the box roughly as soon as you sprout those organs (or the minute you're born, in the case of lungs). The training process was billions of years of evolution. And, well, given fairly persistent cross-cultural cognitive biases, I expect the conscious thought parts are starting from a pretrained model, too, and all we're doing in school is finetuning ;)
People don't understand that to simulate a single neuron, you need an entire neural network. So 70 billion parameters might at best be equivalent to a million neurons but that is assuming that your neural network architecture is akin to the connections between neurons. Considering the physical sparsity, you might need even more parameters to model the connections of a biological neural network. So less than a million neurons in practice.
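A back-of-the-envelope version of that arithmetic (the ~70K-parameters-per-neuron figure below is an illustrative assumption, not a measured value):

```python
# Rough illustration of the parent's arithmetic, not a measured result.
# Assumption: emulating one biological neuron's input/output behaviour
# takes a small ANN on the order of tens of thousands of parameters.
params_total = 70e9          # LLM parameter count from the comment
params_per_neuron = 70_000   # assumed cost of simulating one neuron

equivalent_neurons = params_total / params_per_neuron
print(f"~{equivalent_neurons:,.0f} neurons")  # ~1,000,000
```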
Jamba seems to be distributed as 21 files of roughly 5 GB each [1], so I guess that's another way of looking at it.
[1] https://huggingface.co/ai21labs/Jamba-v0.1/tree/main
So what? I have seen models distributed as 26x 10GB files.
I love both parent post perspectives on this.
The big (huge?) memory requirement is during training. These LLMs work with high-dimensional vectors, they calculate gradients with respect to those vectors, and they do updates that require the state of the optimizer. If you have 3 particles in 3 dimensions and you need their forces, that creates 3 new 3D vectors, and once you update their positions along those forces they also carry momenta. Now generalize this simple 3-body physics to the typical 60-layer creatures inside the LLM, with vectors of several thousand dimensions, interactions/weights that scale like the squares of those vectors, a total parameter count that adds up to tens to hundreds of billions, and then take derivatives and start keeping track of momenta. It is a feat of modern engineering that some groups can train such models efficiently. I hope we will see more of the training stories becoming public in the near future.
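A concrete sketch of where that memory goes. The dtype split below (bf16 weights and gradients, fp32 Adam moments) is common practice but an assumption here, and activations are ignored entirely; the ~52B parameter count is just what the ~105 GB of bf16 shards in the sibling comment implies:

```python
# Rough training-memory estimate for a dense model, ignoring activations,
# which add a large batch- and sequence-length-dependent chunk on top.
def training_memory_gb(n_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=8):
    """bf16 weights + bf16 grads + Adam's two fp32 moment tensors.
    An fp32 master copy of the weights often adds another 4 bytes/param."""
    return n_params * (weight_bytes + grad_bytes + optimizer_bytes) / 1e9

def inference_memory_gb(n_params, weight_bytes=2):
    """Weights only; the KV cache comes on top of this."""
    return n_params * weight_bytes / 1e9

n = 52e9  # roughly what 21 x ~5 GB bf16 shards imply (~105 GB / 2 bytes per param)
print(f"training : ~{training_memory_gb(n):.0f} GB before activations")
print(f"inference: ~{inference_memory_gb(n):.0f} GB before the KV cache")
```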
This is wrong. You need big memory during inference too.
The difference there is you can use tricks like quantisation and offloading to CPU to reduce it somewhat at the cost of accuracy and/or speed.
Not sure what you mean by wrong. I have never yet encountered a case where training an LLM (no matter what the architecture) required only modest memory; I was pointing out that the typical memory requirements for training are much higher still than the typical requirements for inference.
Training takes roughly 3x the memory used by inference, and is usually run at a much larger batch size.
Two things I'm curious to know:
1. How many tokens can 'traditional' models (e.g. Mistral's 8x7B) fit on a single 80GB GPU?

2. How does quantization affect the single transformer layer in the stack? What are the performance/accuracy trade-offs that happen when so little of the stack depends on this bottleneck?
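For question 1, a crude way to estimate it. The attention config is the published Mistral-7B/Mixtral one (32 layers, 8 KV heads, head dim 128); the 4-bit weight footprint and fp16 cache are assumptions, and framework overhead and activations are ignored, so treat the result as an upper bound rather than a direct comparison with Jamba's 140K figure:

```python
# Back-of-the-envelope: how many tokens of fp16 KV cache fit next to the
# quantized weights on an 80 GB card.
GPU_GB = 80
n_layers, n_kv_heads, head_dim = 32, 8, 128        # Mistral/Mixtral attention config
kv_bytes_per_token = 2 * n_layers * n_kv_heads * head_dim * 2   # K and V, fp16 -> 128 KiB

weights_gb = 46.7e9 * 0.5 / 1e9    # Mixtral 8x7B at ~4 bits/weight, ~23 GB (assumed)
free_bytes = (GPU_GB - weights_gb) * 1e9
print(f"~{free_bytes / kv_bytes_per_token / 1e3:.0f}K tokens of KV cache")
```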
Mixtral 8x7b runs well (i.e., produces the correct output faster than I can read it) on a modern AMD or Intel laptop without any use of a GPU - provided that you have enough RAM and CPU cores. 32 GB of RAM and 16 hyperthreads are enough with 4-bit quantization if you don't ask too much in terms of context.
P.S. Dell Inspiron 7415 upgraded to 64 GB of RAM here.
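Rough arithmetic for why 32 GB is enough at 4-bit. The ~46.7B total parameter count is the commonly cited figure for Mixtral 8x7B; the ~4.5 bits/weight is an assumption to account for the scales and zero-points that block-wise quantization formats store:

```python
# Approximate RAM needed to hold Mixtral 8x7B quantized to ~4.5 bits/weight.
n_params = 46.7e9
bits_per_weight = 4.5
model_gb = n_params * bits_per_weight / 8 / 1e9
print(f"~{model_gb:.1f} GB for weights")  # ~26 GB, leaving some headroom for KV cache in 32 GB
```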
That 80 GB is compressed human knowledge being applied to that 1 MB.
That's all the world's knowledge compressed into 80 GB. It's not analysing 1 MB of data, it's analysing all of that knowledge plus an additional 1 MB.
Compared to the human brain they are shockingly efficient. It's the hardware that isn't, but that is just a matter of time.