On AlphaTensor’s new matrix multiplication algorithms (fgiesen.wordpress.com)
279 points by matt_d on Oct 7, 2022 | 76 comments



A word of caution for anyone who wants to follow up on this research: when somebody promotes a fast algorithm, there is often a catch. In this case the issue is numerical stability. There is a theorem that states that any algorithm for n-by-n matrix-matrix multiplication that is componentwise forward stable (as good as it gets in this situation) must necessarily use n^3 scalar multiplications. The authors will therefore waste their time if they carry out their plans and try to optimize for stability. The standard algorithm has this nice property, and no faster algorithm can have it. The question of fast matrix multiplication was raised recently on mathoverflow.net, see https://mathoverflow.net/q/421304/110176 and the answers given there.
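To make the componentwise notion concrete, here is a minimal sketch (assuming NumPy; the sizes and seed are arbitrary) of the kind of guarantee the standard algorithm gives: each entry of the computed product is off from the exact one by roughly n times machine epsilon relative to the corresponding entry of |A||B|.

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((256, 256)).astype(np.float32)
    B = rng.standard_normal((256, 256)).astype(np.float32)

    C32 = A @ B                                         # standard algorithm in float32
    C64 = A.astype(np.float64) @ B.astype(np.float64)   # higher-precision reference

    # componentwise relative error: |C32 - C64| stays a small multiple of (|A| @ |B|)
    err = np.abs(C32 - C64) / (np.abs(A) @ np.abs(B))
    print(err.max())   # roughly on the order of n times float32 epsilon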


We're talking about algorithms for TPUs, which quietly quantize your float32 matrices to bfloat16 behind your back [0]. This is aimed at a crowd that doesn't care about stability.

[0] https://cloud.google.com/blog/products/ai-machine-learning/b...


There's a big difference between not caring about stability and being willing to trade precision for better memory bandwidth in an application that doesn't benefit from increased precision. When doing large training jobs on TPUs, stability is paramount! It's true that you have to know more about what you're doing when you reduce bit-depth - the horrors of floating point are harder to ignore, and it's wildly inappropriate for many scientific computations. However, the reduction of bit-depth is likely to continue as we seek to make modern models more efficient and economical to train and use.


What does this mean in practice? For ML, we usually don't care if a weight is 0.05 or 0.10 because we have millions of weights. We do care if one is 1.237e+27 instead of 1.237e-3, though.


Numerical errors have the annoying tendency to accumulate if you're not careful. So doing one matrix operation with low precision might be okay, while doing a dozen might completely garble your result.
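A rough illustration of that accumulation (assuming NumPy; the sizes are arbitrary): one float16 matrix-vector product is fine, but chaining a dozen of them lets the rounding error pile up relative to a float64 reference.

    import numpy as np

    rng = np.random.default_rng(1)
    M = rng.standard_normal((128, 128)) / np.sqrt(128)   # keep the norm around 1
    x16 = rng.standard_normal(128).astype(np.float16)
    x64 = x16.astype(np.float64)
    M16 = M.astype(np.float16)
    for _ in range(12):            # a dozen products in a row
        x16 = M16 @ x16            # rounded to float16 at every step
        x64 = M @ x64              # full-precision reference
    print(np.max(np.abs(x16 - x64)) / np.max(np.abs(x64)))   # error grows with depth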


This is not that relevant for ML. Each gradient pass will re-compute your cost function and the gradients, so errors are not likely to accumulate. The main thing is to not make errors big enough that you end up in a completely different part of the parameter space, derailing progress, which is what the above commenter points out.


It is extremely relevant for ML.

I am familiarizing myself with recurrent neural networks, and getting them trained online is a pain - I get NaNs all the time except for very small learning rates, which actually prevent my networks from learning anything.

The deeper the network is, the more pronounced the accumulation of errors in online training becomes. Add 20-30 fully connected (not highway or residual) layers before the softmax and you'll see wonders there; you won't be able to keep anything stable.


This isn't true in general. Very specific ML algorithms that were likely developed with years of blood and sweat and tears may have this kind of resiliency, but I've been in the numerical weeds enough here that I wouldn't bet on even that without a real expert weighing in on it - and I wonder what the tradeoff is if it's true there. It's very easy to have numerical stability issues absolutely crater ML results; been there, done that.

I have some ~15 year old experience with the math behind some of this, but actually none with day-to-day deep learning applications using any of the now-conventional algorithms, so my perspective here is perhaps not that of the most pragmatic user. The status quo may have improved, at least de facto.


I'm not really sure there is evidence for that. In fact, depending on your interpretation of why posits[1] work, we may even have empirical evidence that the opposite is true.

1. https://spectrum.ieee.org/floating-point-numbers-posits-proc...


When building an MCMC sampler I was too lazy to properly code a matrix approximation needed to avoid a mathematical black hole and the corresponding underflow. It was cheaper to just ignore the faulty simulations.

Turns out our results were better than the papers we compared to, both in time and precision.

I am not that familiar with ML, but can't you just ignore those faulty weights?


With MCMC, depending on application, it seems risky to just toss out the NaN/inf results. I'd guess these numerical issues are more likely to occur in certain regions of the state space you're sampling from, so your resulting sample could end up a bit biased. In some cases the bias may be small or otherwise unimportant, so the speed-up and simpler code of filtering NaN/inf results is worth it, but in other cases (like when the MCMC samples feed into some chain of downstream computations) the bias may have sneaky insidious effects.


I didn't think deeply about this back then since my parameter estimates were close to or better than the literature I compared to, but now I'm interested in checking the distribution of those NaN/inf. If I recall correctly they were uniformly distributed throughout an adaptive phase.


When people talk about AI taking over the world, a funny image pops up in my head where a robot is trying to enter a frying pan. When you ask it why it's doing that, it says "because I feel like [NaN, NaN, 2.45e24, NaN]", which is a perfectly valid reason.

I'm not at all caught up with this side of ML, but my first instinct is that faulty weights would lead to interpretability issues. The numbers represented by NaN/Inf vastly outnumber the ones within precision range, so interpreting them is much more of a guess.


Weight changes in one neuron can have a dramatic, nonlinear, and not obviously predictable impact on the performance of the full model.


In numerical analysis 101 you learn not to use algorithms that lack certain properties, and numerical stability is one of them.

What good will it do to compute something if its error is unbounded?

The accumulation of roundoff errors is, generally speaking, unavoidable, but fortunately, when it is linear, the errors tend to stay small.


"A considerable group of numerical analysts still believes in the folk “theorem” that fast MM is always numerical unstable, but in actual tests loss of accuracy in fast MM algorithms was limited, and formal proofs of quite reasonable numerical stability of all known fast MM algorithms is available (see [23], [90], [91], [62], and [61])." https://arxiv.org/abs/1804.04102


My concern is that there are not enough people who are qualified to determine if a fast algorithm can be used or not. It feels reckless to include less stable algorithms in a general purpose library when the vast majority of users are mainly concerned with speed and blissfully unaware of the pitfalls of floating-point arithmetic.


My reading of the paper is that the new 4x4 algorithm only works in Z/(2), where there are no issues of roundoff errors. (Z/(2) is the field of integers modulo 2.) The paper seems to say that for real numbers, Strassen is still the best known algorithm for the 4x4 case.
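For anyone who hasn't worked with finite fields, a quick sketch (assuming NumPy) of why stability simply isn't a question there: arithmetic mod 2 is exact, so there is nothing to round.

    import numpy as np

    rng = np.random.default_rng(2)
    A = rng.integers(0, 2, size=(4, 4))
    B = rng.integers(0, 2, size=(4, 4))
    C = (A @ B) % 2      # 4x4 product with entries reduced mod 2
    print(C)             # exact result, no roundoff possible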

(Disclaimer: googler, I have nothing to do with this research.)


For real numbers the decomposition rank is 49 for both AlphaTensor and Strassen, so they're equivalent – wouldn't really say Strassen is better.


I am not a fan of Google. But this is such a bizarre, deliberately misleading, snarky comment. Lower-precision floating point arithmetic is very common in ML training. There's no 'behind your back' going on here.


If there are in fact stability issues, I wonder which is cheaper: using this fancy algorithm or changing bfloat16 to something like bfloat14 and using a more stable matmul.


I don't see how that can be true. Lack of precision is one thing, lack of stability is very different.

Instability leads to divergence from the true answer, and I would expect that to mean super-linear divergence, which would quickly destroy any meaningful result (=> chaotic behaviour). But I'm not an expert.


In practice this doesn't happen, because numerically unstable NNs tend to have bad loss. A simple way to see this is that instability means the network is highly sensitive to the inputs, which means it will give wildly different results for basically the same input, which is wrong. Furthermore, if the weights of your NN are such that you are getting dramatic overflow/underflow that prevents it from correctly predicting, that will show up as a high loss, and the process of training will move towards parameters that don't let these rounding errors blow up.


bfloat16 is stable enough for ML training, which is what TPUs exist for


Be careful: 16 bits was quite a lot of colors around 20 years ago. Now we would laugh at it. Ditto for 640Kb of RAM (who would need more?) etc.

Not trying to be dismissive just saying that... computational limits are limits on what can be done, in the end.


Wrong way to think about the problem. Lowering bits does not mean lowering model capacity. It's the opposite, in fact: it allows you to fit more parameters.


Well, yes, within your constraints. In the end, you are choosing between two aspects. Same as with screens: you could have resolution (1024x768!!) or color (16 bits!).

Edit: the term I could not remember is tradeoff.


Kind of. If we get computers that are 1000x faster, it just becomes a tradeoff between higher precision or 1000x more parameters. The reason resolution has stopped being pushed is that our eyes have severe diminishing returns. It's not yet known whether brains do.


30 years ago. Time flies...


well, you are right and I grow old, I grow old, I shall wear the bottoms of my trousers rolled…

I stand corrected.


Not always though!


While your point about numerical stability is correct in general, there are no numerical stability issues here, and I think this conception, which I've seen in more than one place now, stems from a fundamental misunderstanding of the paper's results. While they _did_ come up with a faster TPU/GPU algorithm too, the primary result is not a fast matmul approximation; it is an exact algorithm consisting of stepwise addition/multiplication operations, and hence is numerically stable and should work for any ring (https://ncatlab.org/nlab/show/ring). AlphaTensor itself does not do the matrix multiplication; it was used to perform an (efficiently pruned) tree search over the space of operations to find an efficient, stable algorithm.


Directly from the paper’s "Discussion" section:

> One important strength of AlphaTensor is its flexibility to support complex stochastic and non-differentiable rewards (from the tensor rank to practical efficiency on specific hardware), in addition to finding algorithms for custom operations in a wide variety of spaces (such as finite fields). We believe this will spur applications of AlphaTensor towards designing algorithms that optimize metrics that we did not consider here, such as numerical stability or energy usage.


Right, but doesn't that just mean it could potentially also be used to design algorithms that have componentwise numerical stability over some kind of floating-point standard, whereas this result, being over finite fields, should by definition be numerically stable?

(apologies if I misunderstood, I wasn't calling you out specifically but a generalized misconception I've noticed in a lot of other discussions so far)


In practice you don't need to recurse all the way down. One level of Strassen is enough to get real speedups (if speedups are possible at all), and certainly for deep learning, the instability introduced by a single level will not matter.
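For concreteness, a sketch (my own, assuming NumPy) of what one level of Strassen means: split once into quadrants, do seven half-size products with the ordinary algorithm, and recombine. The error it introduces relative to a plain matmul is measurable but small.

    import numpy as np

    def strassen_one_level(A, B):
        # one recursion level of Strassen: 7 half-size products instead of 8
        h = A.shape[0] // 2
        A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
        B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
        M1 = (A11 + A22) @ (B11 + B22)
        M2 = (A21 + A22) @ B11
        M3 = A11 @ (B12 - B22)
        M4 = A22 @ (B21 - B11)
        M5 = (A11 + A12) @ B22
        M6 = (A21 - A11) @ (B11 + B12)
        M7 = (A12 - A22) @ (B21 + B22)
        C = np.empty_like(A)
        C[:h, :h] = M1 + M4 - M5 + M7
        C[:h, h:] = M3 + M5
        C[h:, :h] = M2 + M4
        C[h:, h:] = M1 - M2 + M3 + M6
        return C

    A = np.random.rand(512, 512).astype(np.float32)
    B = np.random.rand(512, 512).astype(np.float32)
    print(np.max(np.abs(strassen_one_level(A, B) - A @ B)))   # small, but nonzero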


Do we need forward stability in the componentwise relative sense? This very much depends on the underlying real life application. If we ignore the question of accuracy and stability, then we just might endanger the people who depend on our software.


That's why I specified the application was deep learning.


They’re obviously talking about bfloat on TPU.

The day when a lot of wrong math adds up to a computer drawing a pretty picture. Who would have thought.


Excuse my lack of knowledge here, but what is stability?


Numerical stability means the algorithm keeps an eye on overflow and underflow all the way through its execution, while it performs the operations on the matrix.

This means that regardless of how big your matrix is, or how big or small your numbers are (or even the relation between them), your algorithm is going to be stable and accurate. If it can't be stable, the library must let you know, so that you can keep the numbers scaled such that they are neither too large nor too small for the operations you need to execute.
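A toy example of that scaling (assuming NumPy; the numbers are contrived): the entries of A and B fit comfortably in float16, but their products overflow, so you scale down first and scale back up at higher precision.

    import numpy as np

    A = np.full((2, 2), 300.0, dtype=np.float16)
    B = np.full((2, 2), 300.0, dtype=np.float16)
    print(A @ B)    # each entry would be 180000, past float16 max (~65504): inf

    s = 1.0 / 256
    C = ((A * s) @ (B * s)).astype(np.float32) / (s * s)   # scale down, multiply, scale back
    print(C)        # roughly 180000, recovered without overflow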


When working over a finite field there are no stability issues though, since arithmetic is exact.


It will always depend on the application.


While I agree with other commenters, and also with the article, that this is not like "boom, DeepMind magically speeds up everything" (you still have to look in more detail at numerical stability etc. depending on your use case), there is still something big in here: in the past, algorithms were an almost exclusive product of long and deep thinking by experts. Now we have seen that AI can be used for algorithm discovery. This can actually have quite a big impact. All the tiny improvements can add up: some improvement in sorting, some in matrix multiplication, some in lookups, you get the point. All of that can accumulate into business advantages.

So yes, I think this is an important first result.


Proofs are a pretty important part of algorithm development. Proofs of correctness as well as proofs of algorithmic complexity bounds. At the moment that doesn't seem like something this type of approach can do much for, though perhaps it could be combined with work on automated theorem proving.


These algorithms are provably correct, and it will only generate provably correct algorithms (as the paper goes into).


But in this case, it's extremely easy to prove that the algorithm they propose is correct, or am I missing something?


"Correct" means different things in different contexts. In a fairly standard application of matrix multiplication you don't multiply matrices A and B to get AB, instead you have some floating point approximation to A, and some floating point approximation to B, and you want something that you can be confident approximates AB with some bounds. The two most important characteristics of an algorithm are the runtime and the error bounds you can prove.

Someone smart once said getting the wrong answer in time O(1) is very easy.


> Someone smart once said getting the wrong answer in time O(1) is very easy.

Reminds me of a joke about a guy at a job interview:

    "So, what kind of skills do you have?"
    "I can do mental multiplication really fast."
    "Ok, what's 102 times 376?"
    "87843"
    [enters numbers in calculator] "That wasn't even close to correct."
    "Yeah, but it was fast."


Sorry, but this simply does not make any sense.

Algorithmic correctness does not vary in different contexts.

Algorithmic usefulness/applicability does.

You are confusing the two.

Correctness here means that it provably generates a result that meets the definition of correct matrix multiplication.

In this case, they prove that all algorithms generated do (and that the system will actually only generate provably correct algorithms).

Applicability here is whether, when applied to a particular {not-infinite precision computer, use case}, it is viable to use it.

That does not affect whether the algorithm is correct or not, only whether you can use it to achieve a particular result.

If I have a computer with 1 bit of floating point precision, that does not make the algorithms all suddenly incorrect. Within the bounds of what I can provide (not a lot), they still function exactly as they are supposed to. If I need 75 significant digits on this computer, it simply means that they are not useful for my computer because it cannot generate enough significant digits from them to be useful. That is totally orthogonal to whether the algorithms function as designed.


An algorithm for "matrix multiplication" that assumes infinite precision arithmetic is used is essentially useless (there are some niche uses for matrices with entries in finite fields/rings). An algorithm for matrix multiplication usually comes with more than that, in particular some sort of error bound. Note that these bounds are mathematical, part of the abstract algorithm, and not something that is a property of any particular implementation (your last paragraph suggests you might be confused by this). You want a statement that says if A' is close to A (in some relevant distance measure), and B' is close to B, then your algorithm gives something close to AB.

The correctness you have to prove includes proving that your error bounds are what you say they are.


Again, useless is not the same as incorrect. It is still correct whether it can be used somewhere due to implementation limitations or not.

I'm really unsure how you can possibly argue otherwise.

It's like arguing that a string algorithm is incorrect because it doesn't run fast enough on strings to be usable on any current computer. It's still correct. It's just not usable.

Unlike correctness, useless is very context specific. 100 years from now, an infinite precision arithmetic algorithm may be entirely useful.


I am arguing that a matrix multiplication algorithm, like essentially any other numerical algorithm, consists of (at least) the following two parts

1. The actual steps you have to follow.

2. Some form of error bounds/analysis that tell you how good/bad the output will be given approximate inputs.

In order to prove correctness you have to prove that the procedure gives the error bounds you claim. The error bounds are something you mathematically have to prove.


"In order to prove correctness you have to prove that the procedure gives the error bounds you claim"

You don't get to just add your own requirement for correctness and then force people to prove it?

They claim a specific thing - they prove that thing. That thing suffices to prove that it succeeds at matrix multiplication. You for some reason really just don't like that, as far as I can tell, and argue it doesn't suffice for usefulness (which I agree with).

Matrix multiplication, and "essentially any other numerical algorithm", is not defined in terms of the error bounds for correctness. That is just BS. The error bounds depend on implementation factors, and as such, they are totally unrelated to correctness.

Let's take a look: https://en.wikipedia.org/wiki/Matrix_multiplication

I have read the entire definition, nowhere does it refer to error bounds as a requirement for successful matrix multiplication!

The word "error" does not even appear on the page

Since it's wikipedia, I also pulled out my college math books. Same thing.

They prove correctness without any reference to error bounds. Those are accepted proofs.

I don't see a single basic proof that has error bounds as part of correctness.

It, again, wouldn't make any sense, because error bounds depend on implementation factors.

So again, you simply can't add your requirement to correctness just because you like it. They still remain where they should be - usefulness for application.


You seem to be confusing the mathematical definition of matrix multiplication with an explicit algorithm to compute it. The wiki page you linked is almost entirely about the abstract mathematical definition and talks about algorithms only for a single paragraph, which is about computational complexity. If you look at the wiki page for a particular algorithm (e.g. Strassen: https://en.wikipedia.org/wiki/Strassen_algorithm) then of course they talk about stability in comparison to the naive algorithm (although the page could be improved quite a bit).

If your college textbooks do not mention error analysis, conditioning and stability then they are not numerical linear algebra books worthy of the name. Check out a reference like Trefethen and Bau's Numerical Linear Algebra for example. This book has a whole part (out of the 7 parts in the book) talking about conditioning and stability, and these ideas are present throughout other parts as well.

Once more the type of analysis I'm talking about is emphatically not implementation dependent. It is a property of the algorithm itself. For an example of the sort of statement I mean check out theorem 3.1 of this paper: https://arxiv.org/abs/math/0603207. If you disagree with me then please indicate what sort of "implementation factors" appear in the statement of the theorem.


Correctness of a mathematical algorithm is of course defined by whether it meets a mathematical definition. In this case, correctness of matrix multiplication is only defined by whether it meets the mathematical definition of matrix multiplication. That's it. That's the whole thing.

Correctness of all computer science algorithms is defined by whether they meet a particular algorithmic specification. That's literally the definition of correctness:

https://en.wikipedia.org/wiki/Correctness_(computer_science)

"In theoretical computer science, an algorithm is correct with respect to a specification if it behaves as specified"

The specification is the matrix multiplication definition given right on that Wikipedia page. This algorithm meets it. It is therefore correct. (Again, "error" does not appear on the page here either.) There is no separate, special definition for "correctness (mathematical)" or "correctness (eigenket)". You really seem to want there to be one, but it ain't there.

You really really really don't want to let this go, but the problem is - nothing, anywhere, agrees with you that correctness and usefulness are the same thing. Nor can you cite any reference, like I just did, to correctness that requires it to do anything other than meet the specified mathematical definition. Your papers don't do it, my books don't do it. Nothing does it. Because it's not a thing.

My college textbooks talk about error analysis. I did not claim otherwise. They talk about it in the context of how to make an algorithm useful for a particular purpose, not about correctness. They are not making a ridiculous claim like you are.

I'm not going down this path anymore with you. Believe what you want, the rest of us will continue to not confuse it, and sources that people look at (textbooks, wikipedia, etc) will continue to not lead them astray.

I can only hope that at some point, you too stop trying to do so.


If you look up numerical algorithms on Wikipedia you will find plenty of discussion of error bounds. Of course matrix multiplication performed on exact numbers doesn't have errors; the calculation is exact. Computers and the algorithms they run do not have that luxury when dealing with approximations.


Of course you will, because those are about how they are implemented on computers with limited precision.

Your point is exactly mine - correctness is usually defined on exact numbers, which do not have error bounds. Usefulness is defined by particular implementation choices and precision choices when implemented on a particular computer.

The entire argument here is (crazily) that you can't prove correctness without error bounds and that correctness is always context specific. Of course you can, and of course it isn't. Just like the Wikipedia algorithm shows.

That may or may not make it useful for a particular application.


There are lots of methods (e.g. LU) that are used millions of times a day, but don't have good error bounds. Most matrix algorithms used in practice have something resembling a proof of error bounds that everyone ignores because they don't do a good job of describing the error that you actually get when you use the methods.


What does "correctness" mean? It's impossible to talk about correctness without specifying what you are intending to compute. Numerical algorithms frequently have correctness conditions specified in terms of a precision (e.g. compute a floating point approximation to f(x) to within 1eps), and any algorithm that doesn't meet that precision is incorrect with respect to that correctness condition, by definition.

So you can think of this paper as saying "suppose you have a correct multiplication and addition operation on the field of interest. Then this algorithm for multiplying matrices over that field, which is composed of a sequence of those multiply and add operations, computes the correct matrix product." That is a perfectly provable kind of fact, that as you say doesn't stop being the case if you switch computers or something.

But fpmul and fpadd on your computer don't satisfy the condition for this proof! Therefore it doesn't apply, except by rough analogy. Then, other people might be interested in a different kind of proof, of a fact more like: "given two matrices of floating point numbers and fpmul and fpadd operations that are within .5eps precision, this sequence of fpmul and fpadd operations provably computes the matrix product such that the eigenvalues of the product are within .5eps of the true values". (edit to add: you could then also prove that the implementations of fpmul and fpadd on your computer satisfy the first condition, or replace them with implementations that do). That is also a well-specified correctness condition, and it is a correctness condition that is not necessarily satisfied by replacing "multiply" and "add" in the first algorithm with "fpmul" and "fpadd". This is what it means for "correctness [to be] context dependent". It doesn't invalidate the first proof, it just means we are interested in a different correctness condition than the first proof proves.

(edit to format, and add: of course different conditions on precision are relevant to different applications. Maybe I need my matrix elements to be computed within some absolute error bound, but my friend needs them computed within a relative error bound. Or I need it to be able to work on subnormal numbers within a certain precision, but my friend only needs it to be precise for numbers between 1 and 2. Different algorithms may satisfy one condition but not the other.)
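A toy illustration of the distinction drawn above (my own, using Python's exact Fraction type): the very same sequence of multiplies and adds is exactly correct when the underlying multiply and add are exact, and only approximately correct when they are floating-point.

    from fractions import Fraction

    def matmul2(A, B):
        # a plain 2x2 product written as explicit multiplies and adds
        return [[A[0][0]*B[0][0] + A[0][1]*B[1][0], A[0][0]*B[0][1] + A[0][1]*B[1][1]],
                [A[1][0]*B[0][0] + A[1][1]*B[1][0], A[1][0]*B[0][1] + A[1][1]*B[1][1]]]

    A = [[Fraction(1, 3), Fraction(2, 7)], [Fraction(5, 11), Fraction(1, 9)]]
    B = [[Fraction(4, 13), Fraction(1, 2)], [Fraction(3, 5), Fraction(2, 3)]]

    exact = matmul2(A, B)                                     # exact over the rationals
    approx = matmul2([[float(x) for x in row] for row in A],
                     [[float(x) for x in row] for row in B])  # same steps, rounded inputs/ops
    diff = max(abs(exact[i][j] - Fraction(approx[i][j])) for i in range(2) for j in range(2))
    print(exact[0][0], approx[0][0], float(diff))             # tiny but nonzero discrepancy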


Thanks for this comment. I think I started in this comment section trying to say something like this, but said it in a bad way, and then got progressively more confused and less clear in my responses as the day went on.

This is much more clear than anything I would have been able to write.


I haven't read the paper, but from the blog post I get the sense that finite fields is one of the main claims of usefulness for this result.

I do agree that "'correct' means different things in different contexts."


It will be interesting to see how well this method performs on other types of algorithms. Matrix multiplication was quite convenient: thanks to the divide-and-conquer approach you only need to find an algorithm for 4x4 matrices (or other small matrices), and there is an easy way to prove the correctness of the algorithm.


ATLAS tried to auto-tune BLAS quite a while ago -- it works OK (not as good as the hand-tuned libraries with assembly kernels, though).


Is Strassen's algorithm actually used in practice? Oddly enough I find myself doing a lot of matrix multiplication recently. But I am just using cublasCgemm3mStridedBatched from Nvidia's cuBLAS library, and it doesn't appear to be public information how it's implemented. Does anyone know if it's actually using Strassen?

Basically the library described at:

https://developer.nvidia.com/blog/cublas-strided-batched-mat...

I am not quite sold yet on the AlphaTensor stuff, because in practice it often seems like shuffling the data around in GPU memory is more expensive than doing the actual multiplications. It takes longer to move values between regular GPU memory and shared memory than it does to do a multiply, right? So for all these algorithms that are optimizing the number of arithmetic operations, it isn't even clear to me that they're optimizing the right thing, because they require that you shuffle your data around in weird ways, and they don't generally measure the number of "memory moves" that are needed.

That said, I would be happy to drop in a replacement for cublasCgemm3mStridedBatched and test out if it worked better for me! It doesn't seem like these new AlphaTensor matrix multiplication routines are available as plain old c/c++ libraries yet, though.


It's useful enough for the relatively small class of people dealing with matrices of sizes 8192×8192 and above, more so for those writing backend libraries for targeted computation architectures that can utilise the various weights for moving blocks of data from input source streams between computation nodes, etc.

You're correct - the gains come from really knowing the computational architecture and using this approach to find tweaks that optimise operations, where those operations aren't just atomic mults and adds, but include piped multiply-adds and data moves.


There are many people doing larger matrices, but of course those are mostly sparse, and so this isn't relevant to them.


Isn't there some heuristic that uses advanced multiplication algorithms if the matrix is big enough? Perhaps also checking for sparseness if that matters.


IIRC, Strassen's algorithm is less stable, so it isn't just a "if the matrix is big enough, go for it" sort of thing, necessarily... although, I've never looked into exactly where the issue shows up.

I wonder how well it parallelizes.


Why wouldn't you just use a warp and do the memory moves on the GPU?


That is correct, if I understand you correctly, but that doesn't solve the entire optimization problem. You still have to figure out how exactly to handle tiles and transfers to shared memory. This might be a good page for answering this question:

https://docs.nvidia.com/deeplearning/performance/dl-performa...


I wonder, was/is the AI able to reproduce the best known algorithms we already have? Can it do so reliably?

If this algorithm is only good for large matrices, can it reproduce the one for small matrices?


Some good critique on this result: https://twitter.com/cHHillee/status/1577713102434361344

TL;DR: it seems like the baseline Strassen implementation they used is questionable with respect to how optimal it really is in the first place.


Also note: a few tweets down in this thread, the tweeter says "you're right, I retract my criticism". So maybe not an extremely well-thought-out critique after all. On Twitter, of all places.


I retracted my specific criticism about not comparing to a regular matmul, but I still keep my criticism about having weak baselines :)


Point well taken. To be clear, I didn't think the tweet exchange was out of line in any way.


Possibly sorting networks could get practical benefit from the Alpha treatment.

I think the largest sorting network proven optimal is for a block size in the teens (13?).

Analogously to matrix multiplication, it matters most where comparisons are much more expensive than swaps.
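For anyone unfamiliar, a sorting network is just a fixed sequence of compare-exchange operations; here is the standard optimal 5-comparator network for 4 inputs as a tiny Python sketch. The search problem, where an AlphaTensor-style approach might help, is finding short comparator sequences like this for larger n.

    def sort4(a):
        # fixed compare-exchange sequence; 5 comparators is optimal for n = 4
        for i, j in [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]:
            if a[i] > a[j]:
                a[i], a[j] = a[j], a[i]
        return a

    print(sort4([3, 1, 4, 2]))   # [1, 2, 3, 4]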



