Many parts of TensorFlow required Python; at least when I worked there a few years ago, it was nearly impossible to compile XLA into a SavedModel and execute it from pure C++ code.
Not sure about that specific combination, but since everything in JAX is functionally pure, it's generally really easy to compose libraries. E.g. I've written code that embedded a Flax model inside a Haiku model without much effort.
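For a rough sense of why this works, here's a minimal sketch (not the commenter's actual code; the module names are made up). Both libraries reduce to pure init/apply functions over parameter pytrees, so a Flax module's apply is just another pure function you can call inside a Haiku-transformed forward pass, threading its params in as an extra argument:

    import jax
    import jax.numpy as jnp
    import haiku as hk
    import flax.linen as nn

    # Inner Flax module (hypothetical example).
    class FlaxBlock(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            return nn.relu(nn.Dense(self.features)(x))

    flax_block = FlaxBlock(features=32)

    # Outer Haiku model: the Flax apply is pure, so it composes with
    # Haiku layers as long as its params are passed through explicitly.
    def forward(x, flax_params):
        h = hk.Linear(32)(x)                   # Haiku layer
        h = flax_block.apply(flax_params, h)   # Flax layer
        return hk.Linear(1)(h)                 # Haiku layer again

    model = hk.transform(forward)

    rng = jax.random.PRNGKey(0)
    x = jnp.ones((4, 8))
    flax_params = flax_block.init(rng, jnp.ones((4, 32)))
    hk_params = model.init(rng, x, flax_params)
    out = model.apply(hk_params, rng, x, flax_params)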
Completely agree. He ends up sounding kinda amateurish but that's only because (unlike most other podcasters) he's willing to ask questions deep inside the domain of the interviewee.
(Amateurish w.r.t. the domain, I mean, not as an interviewer.)
At the end of the day all the arrays are one-dimensional, and thinking of them as two-dimensional is just an indexing convenience. A matrix multiply is a bunch of vector dot products in a row. Higher-order tensor contractions can be built out of lower-dimensional ones, so I don't think it's really fair to say the hardware doesn't support it.
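To make that concrete, a quick illustrative sketch (my example, in NumPy): the 2-D view is flat memory plus index arithmetic, a matmul is a grid of dot products, and a batched contraction reduces to a loop of plain matmuls:

    import numpy as np

    # A "2-D" array is a flat buffer plus a row-major indexing convention.
    A = np.arange(6).reshape(2, 3)
    flat = A.ravel()
    assert flat[1 * 3 + 2] == A[1, 2]   # index (i, j) -> i * ncols + j

    # Matrix multiply = a grid of vector dot products.
    B = np.arange(12).reshape(3, 4)
    C = np.empty((2, 4))
    for i in range(2):
        for j in range(4):
            C[i, j] = np.dot(A[i, :], B[:, j])
    assert np.allclose(C, A @ B)

    # A higher-order contraction built from lower ones: batched matmul
    # as a loop of ordinary matmuls.
    X = np.random.rand(5, 2, 3)
    Y = np.random.rand(5, 3, 4)
    Z = np.stack([X[b] @ Y[b] for b in range(5)])
    assert np.allclose(Z, np.einsum('bij,bjk->bik', X, Y))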
Einsum notation makes it natural to formulate your model/layer as multi-dimensional arrays connected by (loosely) named axes, without worrying too much about breaking it down into primitives yourself. Once you get used to it, the terseness is liberating.
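As a hedged illustration (my example, not the commenter's): multi-head attention scores written as one einsum over named axes, with no manual transpose/reshape plumbing:

    import jax
    import jax.numpy as jnp

    # Axes: b = batch, q/k = query/key positions, h = heads, d = head dim.
    key = jax.random.PRNGKey(0)
    q = jax.random.normal(key, (8, 128, 4, 64))   # (b, q, h, d)
    k = jax.random.normal(key, (8, 128, 4, 64))   # (b, k, h, d)

    # "Contract d, keep everything else" in one line; the axis labels
    # carry the intent that reshapes and transposes would obscure.
    scores = jnp.einsum('bqhd,bkhd->bhqk', q, k) / jnp.sqrt(64.0)
    print(scores.shape)  # (8, 4, 128, 128)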
Welllll, there seems at least to be some mathematical cheating, in that this is representing a non-well-founded set. (Declare each pixel to be the set of subpixels at the next level down defining it; this forms an infinite descending chain.)
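Spelled out (my notation, not the commenter's): if pixel p_n is the set of its subpixels, one of which is p_{n+1}, you get a descending membership chain, which is exactly what the axiom of foundation rules out:

    % Each pixel p_n contains its subpixels as elements, with p_{n+1}
    % among them, giving an infinite descending membership chain:
    \[
      p_0 \ni p_1 \ni p_2 \ni \cdots, \qquad p_{n+1} \in p_n \text{ for all } n,
    \]
    % which the axiom of foundation (regularity) forbids: no set may sit
    % atop an infinite descending \(\in\)-chain.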
At a high level it's the right answer to the data-center electricity demand problem, which is that we need to make AI hardware more efficient.
Pragmatically, it doesn't make much sense, given that even in a best-case scenario it would take years for this approach to have any real-world use cases. It seems way more likely that efficiency gains in digital chips will happen first, making these chips less economically valuable.
Edit: I guess I'm not sure whether large training runs count as prod or not. They're certainly expensive and mission-critical.