How much is the flash attention algorithm tied to the hardware? For example, in this announcement they mention taking advantage of the async capabilities of the H100 GPUs, which I assume means you won't get those speedups on non-H-series cards. Second, the actual flash attention library requires CUDA, although the algorithm has apparently been ported to Metal. I would imagine that if the algorithm were literally just a pure function it could be implemented for any GPU/ML framework?
Compiler folks: Is there any chance compilers will be able to find optimizations like FlashAttention on their own? Seems like TVM and tinygrad are working in that direction, but I find it hard to believe that would be feasible.
No. Think of it like a different algorithm. You just take the shape of the hardware into consideration when designing the algorithm instead of considering math only.
Seems like TVM
Fair enough. Technically they're still about different things, but it's indeed very close.
and tinygrad
?????? what gives you this impression?
What's the distinction between what TVM does and FlashAttention type optimizations?
There is more than layout / tile scheduling in FA. For example, to be able to fuse all of these together [0] at all, you need to "decompose" the softmax to make it combinable, which requires maintaining some extra statistics. I won't repeat the math here as the original FA paper is already very clear.
[0] so you can avoid materializing intermediate matrices and still be able to compute in blocks.
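If it helps, here is a minimal NumPy sketch of that decomposition (my own illustration, not code from the paper): a softmax-weighted sum can be computed block by block as long as you carry a running max and a running sum of exponentials, and rescale the accumulator whenever the max changes.

    import numpy as np

    def blockwise_softmax_matmul(score_blocks, value_blocks):
        # Running statistics: max (m), sum of exponentials (l), weighted-value accumulator (acc).
        m = -np.inf
        l = 0.0
        acc = np.zeros(value_blocks[0].shape[1])
        for s, v in zip(score_blocks, value_blocks):
            m_new = max(m, s.max())
            scale = np.exp(m - m_new)              # rescale old stats to the new running max
            l = l * scale + np.exp(s - m_new).sum()
            acc = acc * scale + np.exp(s - m_new) @ v
            m = m_new
        return acc / l                             # equals softmax(scores) @ V

    rng = np.random.default_rng(0)
    scores = rng.normal(size=16)                   # one row of QK^T
    values = rng.normal(size=(16, 4))              # the matching rows of V
    out = blockwise_softmax_matmul(
        [scores[i:i + 4] for i in range(0, 16, 4)],
        [values[i:i + 4] for i in range(0, 16, 4)],
    )
    w = np.exp(scores - scores.max())
    assert np.allclose(out, (w / w.sum()) @ values)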
George Hotz has explicitly stated he wants tinygrad to eventually be able to find FA in its search space of algorithms. Actually achieving that is another matter.
In theory, yes, it's "just" some algebraic properties of the math used that allow for substantial reordering, and then you'd add fairly regular polyhedral loop tiling. Just expensive to do, so you'll have to cache the effort.
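As a toy illustration of what the loop-tiling part means (a sketch of my own, not tied to any particular compiler): split the loops of a matrix multiply into blocks so each tile of the inputs is reused while it is, ideally, still hot in cache.

    import numpy as np

    def tiled_matmul(A, B, tile=32):
        # Blocked matrix multiply: the loops are split into tiles so each
        # (tile x tile) block of A and B is reused many times while it is
        # (ideally) still resident in cache / shared memory.
        n, k = A.shape
        _, m = B.shape
        C = np.zeros((n, m))
        for i in range(0, n, tile):
            for j in range(0, m, tile):
                for p in range(0, k, tile):
                    C[i:i + tile, j:j + tile] += A[i:i + tile, p:p + tile] @ B[p:p + tile, j:j + tile]
        return C

    A = np.random.rand(128, 96)
    B = np.random.rand(96, 64)
    assert np.allclose(tiled_matmul(A, B), A @ B)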
The area of e-graph optimizers seems well-suited to this, btw. It's not really deployed outside of some niche tooling though, as it's a big paradigm shift in optimizer pass handling (e.g., it doesn't work well with classic call graphs, so control flow needs to be massively revamped to deploy e-graphs outside/across basic blocks and for loops (break and return not supported!)).
Just discovered e-graphs recently, and I have a good understanding of compilers from taking a compiler class at university.
I would like to understand why you say e-graphs would need control flow to be revamped.
Do you have anything I could read on it?
Kinda tricky if you want to call higher level operators in a wrapped language like Python.
This strikes me as an extremely difficult but not intractable problem.
I'm not sure what the state of the art in compiler optimisation is with regard to data positioning and targeting maximum processor usage.
There was a video on optimisation a while back that showed small optimisations caused increases in speed that were insignificant when compared to the speed variance induced by the memory layout that the optimisation (or even a random change) caused.
While that talk was more focused on getting a signal past the noise, that noise itself is an artifact of compilers not being particularly good at handling a much simpler form of the problem you describe.
CPU and memory architectures are complex once caches and access patterns impact speed.
When you add in GPU architectures to the mix I think you might be in fairly uncharted territory.
Maybe one day.
Of course, since we are in the field of AI, there is also the question of whether a sufficiently smart AI could do this. It depends on the value of "sufficient".
I would like to think that an extremely high level test for an AI model could be to give it something like micrograd and tell it to produce something with the same interface that outperforms torch.
We're not even in the ballpark of being able to do that yet, but it will be interesting when and if that happens.
If anyone wants to port this over to ROCm / AMD MI300x, reach out to me: hello@hotaisle.xyz (we won't ever spam you).
Happy to donate the compute time for this work.
Not trying to be rude but what is the thinking behind this offer? Why would someone do this port for…free save for access to the hardware? What’s the upside for them?
Not a rude question. I'm building public HPC supercomputers, currently focused on AMD hardware. The one I'm about to deploy is Top 150, which is a pretty good start.
The goal is to encourage a developer flywheel: the more developers working with AMD hardware, the more hardware that is needed, the more hardware I can justify buying, the bigger my supercomputers get.
Nvidia has been doing the flywheel for years and it has clearly worked. Why not do the same for AMD? As I said in another thread, anyone who thinks there should be a single provider for all AI compute needs will be on the wrong side of history.
No one person or one company SHOULD have huge control over humanity, I agree.
But practically speaking, this is a unique time in the history of technology, because there are quick feedback loops that make the flywheel you mentioned an insurmountable first-mover advantage.
I'm staking my career and business on you being wrong about the insurmountable part. This is just the beginning of a long road, and I'm not the only one who believes this. My partnership with Dell, Advizex, and a huge soon-to-be-announced datacenter company isn't small beans.
It's much like how I didn't know how the internet would look when I first joined in 1991. But what I can see very clearly from my decades of experience in the tech field is that history is repeating itself with what is happening in AI.
As I'm also prone to say... this isn't a football match where one team needs to "beat" the other. It really is enough to have multiple players in the market and nothing more than that. In fact, I'm more than happy to deploy any type of compute that my customers want me to deploy for them, including Nvidia.
Even Lamini, who were previously AMD-only, just announced [0] that they are partnering with Nvidia. Their software will run equally well on any system. Why? Because it builds a simple bridge from one platform to the next. Reminds me of the Java "write once, run anywhere" slogan. It actually worked pretty well.
I'm not saying it is impossible for other companies to build good and profitable products. Google, AMD, Tesla all have good AI systems.
I'm saying NVDA uses their own chips to help build more chips, AND they are intricately involved in the buildout of the 100B data centers and intricately involved in TSMC roadmaps. That, combined with huge and growing profits, creates even more advantages over competitors.
Obviously this doesn't go on forever; NVDA will never have 100T of profit in a quarter. Years from now the feedback loops will have diminishing returns and there will be commodity AI systems eventually.
I did not use the word impossible. Nobody is arguing that Nvidia won't be the dominant player for a long time. That does not mean that there isn't a good business in being in the game.
Years from now the feedback loops will have diminishing returns and there will be commodity AI systems eventually.
Maybe, but the cat is out of the bag. Before, it was a question of Moore's law and speed, but nobody talks about that anymore... all they talk about is that the need for raw compute (not even the fastest) is officially boundless.
You're the AMD accelerator server company! Such cool work, hope someone takes you up :)
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
How does FA3 fare for consumer GPUs such as 3090 and 4090?
It's Hopper-specific; the improvements are closely tied to Hopper features like warp groups and TMA. For 4090s, you might get a speedup by using the Triton implementation of FP8 attention: https://triton-lang.org/main/getting-started/tutorials/06-fu...
The original flash attention (v1?) took like a year to get added to llama.cpp and only provides single digit percent VRAM savings for typical context lengths and practically no speed boost. Still nice to have, but man was this thing overhyped. I doubt v3 will do more than marginally better on the RTX 5000 series.
On GPU, or on CPU/Metal? For the latter I'm not surprised, but that's because they have a totally different memory/cache hierarchy.
With CUDA offloading; I don't think it runs otherwise at all.
This is one of the most important improvements in all of AI, because it benefits most AI users by giving them access to more, faster, for the same hardware with little to no tradeoffs.
...for all those users with H100s.
Indeed.
Anyone who is doing anything important or at scale would be at least renting those, or even using an abstracted service that is on top of another service.
Those cost savings allow people to train things for cheaper, causing those cost savings to benefit almost everyone doing important stuff in the space.
... which is currently the most cost-efficient and environment-friendly way to do LLM inference [0].
[0] Small footprint time: before B100 ships; for actually large language models; for prefill only; may cause cancer in California.
spoiler: $xxx,xxx hardware required to run
$25k-$30k
If you need to run it continuously for a year
hoping an expert can answer a few Qs I have :)
Is FlashAttention simply a drop-in replacement for the attention operation in an LLM? Can it be used anywhere that an "attention" operation is used? Or does a LLM need to be trained specially to use FA?
How does FA relate to attention strategies like GQA (grouped query attention) or sliding-window attention? Are they orthogonal concepts? Or you need a specific FA implementation for each strategy?
Recently llama.cpp added flash attention support - does this just mean they started consuming a flash attention-provided CUDA kernel or something?
lastly, in this post, they compare FlashAttention to Triton. I thought Triton was like an abstraction layer? Couldn't FA be implemented in Triton? I just don't really get what it means to say "FlashAttention vs. Triton".
Is FlashAttention simply a drop-in replacement for the attention operation in an LLM? Can it be used anywhere that an "attention" operation is used? Or does a LLM need to be trained specially to use FA?
Yes
How does FA relate to attention strategies like GQA (grouped query attention) or sliding-window attention? Are they orthogonal concepts? Or you need a specific FA implementation for each strategy?
Flash Attention is a way of calculating the softmax(QK^T)V part of attention, whereas GQA is a way of calculating the Q, K, and V matrices. Sliding-window attention (less sure about this, there are a bunch of windowed attention techniques) changes the attention mask (the thing that controls which queries can attend to which keys).
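To make the distinction concrete, here is a rough PyTorch sketch (my own, with made-up head counts): GQA only changes how the K/V heads are produced and shared across query heads; whatever computes softmax(QK^T)V afterwards, FlashAttention included, is unaffected.

    import torch
    import torch.nn.functional as F

    batch, seq, d_head = 2, 128, 64
    n_q_heads, n_kv_heads = 8, 2                 # GQA: 8 query heads share 2 K/V heads

    q = torch.randn(batch, n_q_heads, seq, d_head)
    k = torch.randn(batch, n_kv_heads, seq, d_head)
    v = torch.randn(batch, n_kv_heads, seq, d_head)

    # GQA is just this step: replicate each K/V head across its group of query heads.
    group = n_q_heads // n_kv_heads
    k = k.repeat_interleave(group, dim=1)        # -> (batch, n_q_heads, seq, d_head)
    v = v.repeat_interleave(group, dim=1)

    # The attention kernel itself (Flash or otherwise) only sees matched Q/K/V heads.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(out.shape)                             # torch.Size([2, 8, 128, 64])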
Recently llama.cpp added flash attention support - does this just mean they started consuming a flash attention-provided CUDA kernel or something?
I don't use llama.cpp but that sounds about right.
lastly, in this post, they compare FlashAttention to Triton. I thought Triton was like an abstraction layer? Couldn't FA be implemented in Triton? I just don't really get what it means to say "FlashAttention vs. Triton".
They're talking about a previous Flash Attention implementation written in Triton.
1) Pretty much, it's mathematically equivalent. The only software issues are things like managing dependency versions and in-memory data formats, but Flash Attention 2 is already built into HuggingFace and other popular libraries (see the rough usage sketch after this list). Flash Attention 3 probably will be soon, although it requires an H100 GPU to run.
2) Flash Attention 2 added support for GQA in past version updates:
https://github.com/Dao-AILab/flash-attention
3) They're comparing this implementation of Flash Attention (which is written in raw CUDA C++) to the Triton implementation of a similar algorithm (which is written in Triton): https://triton-lang.org/main/getting-started/tutorials/06-fu...
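To make the "drop-in" point from 1) concrete, here is a rough sketch of two common ways to actually use it (the function and argument names follow the flash-attn and HuggingFace Transformers docs as I remember them; treat the details as illustrative):

    import torch

    # Option A: call the kernel directly (needs the flash-attn package and a supported GPU).
    # Tensors are (batch, seqlen, n_heads, head_dim), fp16 or bf16, on a CUDA device.
    from flash_attn import flash_attn_func
    q = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 1024, 8, 64, dtype=torch.float16, device="cuda")
    out = flash_attn_func(q, k, v, causal=True)  # same result as vanilla attention, just faster

    # Option B: let a library swap it in for an existing pretrained model -- no retraining needed.
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",              # any supported model id, used here as an example
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2", # replaces only the attention kernel
    )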
The code has a comment which seems to hint that Tri Dao was working on FA3 as early as April 2022, the month after Hopper/H100 was announced. I find it mildly curious that over 2 years elapsed before the code was released today. Perhaps it's because there are now better solutions in the pipeline?
Tri's publication history has been leaning toward SSM and Mamba-style architectures recently. Unlike Flash Attention, which has quadratic time complexity w.r.t. sequence length, these latest algorithms are subquadratic. Thus they do much less computation, instead of just doing it more efficiently a la Flash Attention.
Dao and Gu published a really long paper this year which demonstrated (among other things) how Mamba/SSM can be formulated such that it’s amenable to acceleration using the same hardware primitives that Transformers benefit from.
Until the strong exponential time hypothesis (SETH) is (dis)proven, the quadratic cost is required or you have to give something up. That's just the cost of exhaustive search.
As proving SETH would also settle P vs NP, I wouldn't hold my breath.
The question is if a particular use case can accept those costs.
I am wondering why flash attention is like 5x slower with variable masking than without it. Lack of good masking support almost zeroes out the optimizations.
Where are you seeing these benchmarks?
I was wondering... this post mentions that ops like sigmoid are very slow.
A lot of modern LLMs use activation functions built on sigmoid or softmax, like SiLU, Swish, and SoLU.
Does Relu take less of a performance hit, and if so, maybe it'd be better to go back to good old relu?
Relu is literally just a linear function that gets clamped to zero at some point, so yes, it's much less computationally intensive than anything involving an exponential function. But I doubt you would get competitive results using such a simple activation.
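For reference, a quick sketch of the two side by side; the cost gap is roughly one compare/select versus one exp plus a multiply per element:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1024)

    relu = torch.clamp(x, min=0)      # max(x, 0): one compare/select per element
    silu = x * torch.sigmoid(x)       # x / (1 + exp(-x)): needs an exp per element

    assert torch.allclose(relu, F.relu(x))
    assert torch.allclose(silu, F.silu(x))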
TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. This frees up registers, which is a valuable resource to increase tile size and efficiency.
My understanding was that while it frees up registers it more importantly lets the hardware handle address generation, which can become a bottleneck as other operations around it become faster.
FlashAttention's algorithmic improvement is mostly just splitting/combining the softmax part of attention, and is itself not totally novel. The overwhelming contribution is implementing that, and all its fiddly pieces, efficiently on Nvidia hardware.
Clarifying:
Given the question: "How much is the flash attention algorithm tied to the hardware?"
The answer is 0.
e.g. you can find generic flash attention recently added in llama.cpp and ONNX (MS needed it for Phi-3, which is needed for Recall).
As an aside, I have no direct knowledge on the novelty question, but IMHO it would devolve the way novelty arguments do in any field: there's always someone else who can claim they did 80% of $X via $X-1, therefore $X is by and large not novel. Ad infinitum.
I think the right analogy for FA is high-quality cache-aware BLAS kernel implementations. The algorithm(s) is (are) clever and (as you note) completely independent of hardware. However, a hardware-naive implementation is approximately worthless. Most of the value of MKL, or Accelerate, or FA is in the careful matching of the parameters and implementation of the algorithm to the capabilities of the hardware it's going to run on.
I definitely don't mean to take away from Tri/FA by mentioning novelty - I'm just repeating from the paper, which refers back to algebraic aggregates [0] in its discussion of their tiled softmax.
[0]: https://web.stanford.edu/class/cs345d-01/rl/olap.pdf
This isn’t true when there is one vendor that’s 90% of the market and 2 maybe 3 generations of hardware to consider. Support A100, H100 and you are supporting most of the current market.
Supporting A100 and H100 is the opposite of being hardware naive, though.
To clarify further, flash attention is explicitly targeting a compute engine with separate MMA and "scalar" vector execution units that allow post-processing the MMA outputs without involving memory bandwidth (though arithmetic intensity, especially relative between the MMA and the "scalar" instructions, is of concern), with a substantial amount of manually-managed L1D$ to use as sub-matrix accumulator, and a linear-in-context-length amount of "VRAM" that requires sensible arithmetic intensity to avoid being a bandwidth bottleneck (iirc in the hundreds when counting the scalar multiplies hiding in the MMA instructions).
This v3 with async might for once be so tied to Hopper that it's not trivially portable to another platform that has the mentioned hardware blocks (AFAIK every AMD GCN card that can do compute shaders would qualify, though they do lack a specialized MMA unit).
The original FA, almost none.
For the latest versions it depends on your abstraction. ThunderKittens [0] provides about the same speedup over FA2 (1.3x-2x) as the article, but is relatively universal across GPUs. For any new hardware there may be hardware-specific features that squeeze out more performance; usually vendors will adopt any new feature that seems to beat them, but you do get fragmented APIs/libraries (which is already true for CUDA).
[0]: https://hazyresearch.stanford.edu/blog/2024-05-12-tk
What do you mean by "relatively universal"? This is CUDA-only [0] with a promise of a ROCm backend eventually. There's only one project I'm aware of that seriously tries to address the CUDA issue in ML [1].
[0] https://github.com/HazyResearch/ThunderKittens?tab=readme-ov...
[1] https://github.com/vosen/ZLUDA
If you read the article I linked, they show that it's entirely based on 16x16 matrices (or "tiles"), which is a fairly standard size across GPUs.
I mean they're building an API to abstract away some of the SKU-to-SKU differences, but the broader point cuts the other way, I think:
The value is in adapting the implementation (either manually at write-time or programmatically at run-time) to the specifics of the hardware.
Also, great line:
To add to the discussion, from a practical perspective: AMD hardware totally sucks and has yet to get a proper flash-attention-2 implementation. ROCm is slowly becoming usable, but it's not close to being comparable with CUDA.
Why is it so hard to port FA2 to the MI300 Instinct?
https://github.com/kailums/flash-attention-rocm
There are a bunch of good answers, but I wanted to succinctly say "practically, quite a bit". Here's a good little rabbit-hole example:
Karpathy's nanoGPT calls flash attention by checking whether torch.nn.functional.scaled_dot_product_attention exists.
Looking at the docs, in reality, most of the time you want this to call out to FA2, which optimizes the kernels on the device to split ops on the softmax of the triangular matrix, as well as to reduce unnecessary movement of blocks of floating point numbers between slow and fast GPU memory.
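Roughly, the relevant bit of nanoGPT looks like this (paraphrased from memory, so treat it as a sketch rather than a verbatim quote):

    import math
    import torch
    import torch.nn.functional as F

    # Use the fused kernel when this PyTorch build has it, else fall back to "manual" attention.
    flash = hasattr(F, "scaled_dot_product_attention")

    def causal_attention(q, k, v):
        # q, k, v: (batch, n_heads, seq_len, head_dim)
        if flash:
            # PyTorch dispatches to FlashAttention / memory-efficient kernels under the hood.
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Fallback: materializes the full (seq_len x seq_len) attention matrix.
        T = q.size(-2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
        att = att.masked_fill(~mask, float("-inf"))
        return F.softmax(att, dim=-1) @ v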
The paper for FA2 almost entirely considers itself through the hardware it's running on.
Conceptually, just a bit; practically (in terms of implementation), a lot. The standard Python implementation internally compiles a kernel for your specific hardware.