
Google/Trax – Understand and explore advanced deep learning
https://github.com/google/trax
======
codingslave
Word is that internally at Google, among a few teams, and also externally,
Trax/Jax are putting up real competition to TensorFlow. Some teams have moved
off of TensorFlow entirely. Combined with the better research capabilities of
PyTorch, the future of TensorFlow is not bright. That said, TensorFlow still
provides the highest performance for production use, and has tons of legacy
code strewn throughout the web.

I would argue that this is not the fault of Tensorflow, but rather the hazard
of being the first implementation in an extremely complex space. Seems like
usually there needs to be some sacrificial lamb in software domains. Somewhat
like Map/Reduce was quickly replaced by Spark, which has no real competitors.

~~~
logicchains
>I would argue that this is not the fault of Tensorflow, but rather the hazard
of being the first implementation in an extremely complex space. Seems like
usually there needs to be some sacrificial lamb in software domains.

I'd agree if Google didn't have a history of building things with (arguably)
unnecessarily complex APIs, like Angular1. I remember when Angular and React
were new, seeing an Angular "cheatsheet" that was around 14 pages; the
equivalent React cheatsheet was only 2 pages. Now, I do love the idea of
Tensorflow 1, essentially a functional DSL to explicitly construct computation
graphs, but Google's implementation of that idea was suboptimal: hard-to-
follow error messages, not intuitive, multiple APIs to do the same thing (and
continual API breakage as new APIs are introduced), difficult to debug. And
even if the graph "compiled" correctly, it could still fail at execution time.
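
A minimal sketch of that failure mode, from memory (names and shapes are made
up): because one dimension is unknown at graph-construction time, the graph
"compiles" fine and the mismatch only explodes inside session.run:

    # Hypothetical TF 1.x example: construction succeeds, execution fails.
    import numpy as np
    import tensorflow as tf  # assumes TF 1.x

    x = tf.placeholder(tf.float32, shape=[None, None], name="x")
    w = tf.Variable(tf.random_normal([10, 5]))
    y = tf.matmul(x, w)  # inner dimension unknown, so this "compiles" fine

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The shape mismatch (12 vs. 10) is only caught here, at run time,
        # with a stack trace pointing at the session rather than at the
        # line that built the bad op.
        sess.run(y, feed_dict={x: np.ones((3, 12), dtype=np.float32)})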

It's like they were building a programming language but lacked anyone with
language design or PL theory background. Which makes sense given that anyone
passionate about language design might prefer to work somewhere closer to
the cutting edge like Microsoft (C#, F#, Typescript, F*...), Facebook
(Bucklescript, Hack), Apple (Swift) or Mozilla (Rust). Google does have its
own languages, Dart and Go, but they're notable for ignoring and rejecting,
respectively, cutting-edge PL theory (e.g. not disallowing null pointers, and
in Go's case not even supporting parametric polymorphism). The day-to-day
languages used at Google are also not particularly appealing to a PL
enthusiast: Python, Java and non-modern C++.

Google software often also seems to care more about enforcing their idea of
"best practices" on the user than about user experience. Tensorflow's C++
support is an example of this: it requires using Bazel. As Bazel doesn't
easily support integration into an existing C++ project, this essentially
means you have to change your whole project over to Bazel just to use
Tensorflow C++, which is a huge amount of effort to go to just to use a
library. Especially when it's probably quicker to just rewrite the model in
PyTorch, which provides a simple header file and static library for linking,
the standard way of distributing C++ libraries. PyTorch also provides a nicer
C++ API, because it's not confined to the ancient C++ standards that Google
enforces, so it can provide a modern API that isn't full of macros (macros in
modern C++ are considered bad practice and should only be used when there's
absolutely no alternative).

To be fair, Google is really good at engineering language runtimes. Dart, Go
and Tensorflow are all impressive works of engineering. They just seem to lack
the organisational DNA for making them really nice to use, which maybe makes
sense given that the main source of their revenue is search/AdWords, the
success of which is primarily driven by superior science/engineering. Compare
that to e.g. Facebook and Microsoft, which were/are in the business of making
pretty things that users like to use (operating system, word processor,
website). Or even
comparing to Apple: people pay a huge premium for Apple phones over Android
phones, in spite of their worse hardware, because of their more appealing
design.

~~~
WnZ39p0Dgydaz1
I would actually argue that the very first version of Tensorflow was great. I
don't quite remember what version number that was, but it was well before 1.0.
It was basically just a better version of Theano with a single way of doing
things and very explicit graphs. I remember I ported all my Theano code
without issues and loved it.

From then on it all went downhill. Lots of duplicate APIs were added and the
code grew increasingly complex and opaque because everyone at Google wanted
in. Everyone wanted a slice of the pie to pad their resumes and internal
performance reviews, to be able to say "Look, I contributed this to TF!
Promote me!" Now it's a typical example of Conway's Law: a big mess that
mirrors how messed up the internal incentive structure at Google is.

~~~
logicchains
Makes sense. I only started using it after it'd already been through tf.nn,
tf.layers, and tf.estimators (which eventually got replaced by tf.keras,
around the time I moved to PyTorch, and around the time dropout was broken for
two releases:
[https://github.com/tensorflow/tensorflow/issues/25175](https://github.com/tensorflow/tensorflow/issues/25175)).
I remember once spending half a day just trying to figure out why a graph that
built correctly kept failing at execution time: I ended up having to
binary-search the code changes because I couldn't find the cause of the error
in the stack trace, something I'd previously experienced only in
template-heavy C++ code. I don't imagine that situation was much better in
1.0.
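
For anyone who missed that era, here's roughly what the duplication looked
like (a sketch from memory; exact signatures shifted across 1.x releases).
The same dense layer, three coexisting APIs:

    # The same dense+ReLU layer in three coexisting TF 1.x APIs.
    import tensorflow as tf  # assumes TF 1.x

    x = tf.placeholder(tf.float32, shape=[None, 128])

    # 1) Low-level tf.nn: you create and wire the variables yourself.
    w = tf.Variable(tf.random_normal([128, 64]))
    b = tf.Variable(tf.zeros([64]))
    h_nn = tf.nn.relu(tf.matmul(x, w) + b)

    # 2) tf.layers: functional wrapper that creates variables implicitly.
    h_layers = tf.layers.dense(x, 64, activation=tf.nn.relu)

    # 3) tf.keras: the object-oriented API that eventually won out.
    h_keras = tf.keras.layers.Dense(64, activation="relu")(x)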

~~~
WnZ39p0Dgydaz1
The error messages have always been an issue, even with early releases. That
being said, it was much easier to trace errors when there were fewer high-
level functions and you had to build the layers yourself. You fully understood
what the graph looked like, and it was mostly obvious where something was
likely to go wrong.

Sure, it's a bit more work, but in the long run it saves time because it
avoids the situation where you need to spend half a day trying to understand
opaque TF code touched and extended by 100 different developers.

EDIT: Also, I LOVE the GitHub issue you posted. It's such a beautiful example
of complexity cost when a large number of people are working on a project and
there's no longer a single person who fully understands how the code interacts
across modules.

~~~
logicchains
I'd love to know what the solution is to prevent that eventually happening to
a project, if one even exists. PyTorch seems to be doing better, but maybe
that's just because it's a younger project. I'd be interested to see a blog
post comparing Facebook's incentive structure for maintaining products to
Google's.

------
zwaps
Is it just me or is there zero explanation to what this actually is?

It somehow "helps" me understand deep learning, but its tutorial/doc is one
Python notebook with three cells in which some nondescript, unknown API is
called to train a transformer.

Huh?

~~~
WnZ39p0Dgydaz1
You're not alone. Even as an ex-FAANG ML researcher, I have no idea what this
is. Is it a collection of well-documented models built on top of Jax?

I could probably figure out what exactly this is if I spent an hour looking
through the code, but it should be made clear in the README.

It's kind of funny, I think it's completely opaque what this actually is to
99% of HN users, but for some reason it's being upvoted because it has Google
and Deep Learning in the name.

~~~
WnZ39p0Dgydaz1
15 minutes of looking through the code, and my best guess right now is that
it's a reimplementation of higher-level primitives, such as optimizers and
layers, found in TensorFlow/PyTorch/etc., but based on a swappable backend
(you can pick Jax or TF), together with a collection of models and training
loops. I think the idea is that most of the code is simpler and more modular
than what you would find in TF, which makes the models easier to read.

However, I don't yet understand what the use case is, or how it helps you to
"learn" anything.

------
nestorD
Note that, in this space, there is also Flax[0], which is likewise built on
top of Jax, bringing more deep-learning-specific primitives (while, unlike
Trax, not trying to be TensorFlow-compatible, if I understand correctly).

[0]: [https://github.com/google-research/flax/tree/prerelease](https://github.com/google-research/flax/tree/prerelease)

------
unityByFreedom
Is this like a layer on top of TensorFlow to make it easier to get started? Is
it meant to compete with PyTorch in that respect?

I wish the title and description were more clear. They make it sound like a
course but it is a library/command-line tool.

~~~
sandGorgon
So TensorFlow 2 is built on top of the Keras API.

It's supposed to be better UX, but PyTorch's really is far superior.
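
A minimal sketch of that Keras-style path in TF 2:

    # Minimal TF 2 / Keras model definition, the API TF 2 standardized on.
    import tensorflow as tf  # assumes TF 2.x

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )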

------
sillysaurusx
I was recently surprised to discover that Jax can't use a TPU's CPU, and that
there are no plans to add this to Jax.
[https://github.com/google/jax/issues/2108#issuecomment-58154...](https://github.com/google/jax/issues/2108#issuecomment-581541862)

A TPU's CPU is _the only reason_ that TPUs are able to get such high
performance on MLPerf benchmarks like imagenet resnet training.
[https://mlperf.org/training-results-0-6](https://mlperf.org/training-results-0-6)

They do infeed processing (image transforms, etc) on the TPU's CPU. Then the
results are fed to each TPU core.

Without this capability, I don't know how you'd feed the TPUs with data in a
timely fashion. It seems like the input pipeline would be starved.

Hopefully they'll bring jax to parity with tensorflow in this regard soon.
Otherwise, given that jax is a serious tensorflow competitor, I'm not sure how
the future of TPUs will play out.

(If it sounds like this is just a minor feature, consider how it would sound
to say "We're selling this car, and it can go fast, but it has no seats." Kind
of a crucial feature of a car.)

Still, I think this is just a passing issue. There's no way that Google is
going to let their TPU fleet languish. Not when they bring in >$1M/yr per TPU
pod commitment.

~~~
yablak
You can use tensorflow's tf.data, I think.

~~~
sillysaurusx
Sort of. The way you do it is to scope operations to tf.device(None), which
selects the TPU's CPU. (I think it's equivalent to using the first device in
sess.list_devices(), which is the CPU.)

You can scope operations in tf.data, sure, but you can also execute arbitrary
training operations. I use this technique to train GPT-2 117M with a 25k
context window, which requires about 300 GB of memory. Only a TPU's CPU has
that much.
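
Roughly, the pattern looks like this (a sketch; parse_and_augment is a
stand-in for whatever preprocessing you need):

    # Pin the input pipeline to the TPU worker's host CPU via a device scope.
    import tensorflow as tf  # assumes TF 1.x on a TPU worker

    def make_dataset(filenames, parse_and_augment):
        # parse_and_augment is a hypothetical user-supplied map function.
        # tf.device(None) clears the device assignment, which on a TPU
        # worker resolves to the host CPU (the first device listed by
        # sess.list_devices()), per the description above.
        with tf.device(None):
            ds = tf.data.TFRecordDataset(filenames)
            ds = ds.map(parse_and_augment, num_parallel_calls=16)
            ds = ds.batch(128, drop_remainder=True)
            ds = ds.prefetch(2)  # keep the cores fed
        return ds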

That's why it was surprising to hear Jax can't do that. It's one of the best
features of a TPU.

------
m0zg
Not sure why one would bother with this. This is a less mature version of
PyTorch. And I know there's XLA and stuff, but I've yet to see any major
benefit from that for research in particular. A ton of time in DL frameworks
is spent in the kernels (which in most practical cases means CUDA/cuDNN),
which are hand-optimized far better than anything we'll ever get out of any
optimizer.

~~~
chillee
If you're talking about Jax, there are a couple of different reasons to bother
for research:

1. Full numpy compatibility.

2. More efficient higher-order gradients (because of forward-mode autodiff).
Naively it's an asymptotic improvement, but I believe PyTorch uses some
autodiff tricks to perform higher-order gradients with backward mode, at the
cost of a decently high constant factor.

3. Some cool transformations like vmap.

4. Full code gen, which is neat especially for scientific computing purposes.

5. A neat API for higher-order gradients.

2 and 5 are the most appealing for DL research; 1, 3 and 4 are appealing for
those in the stats/scientific computing communities (a concrete sketch follows
below).

PyTorch is working on all of these, to various degrees of effort, but Jax
currently has an advantage in these points (and may have a fundamental
advantage in design in some).
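
To make 2, 3 and 5 concrete, a toy sketch with the standard Jax API:

    # grad returns a gradient *function*; composing it gives higher-order
    # derivatives, and vmap vectorizes a function over a batch axis.
    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) ** 2

    df = jax.grad(f)            # first derivative, itself a plain function
    d2f = jax.grad(df)          # second derivative by composing grad
    batched_df = jax.vmap(df)   # evaluate df pointwise over a whole array

    xs = jnp.linspace(0.0, 1.0, 8)
    print(batched_df(xs))  # df at 8 points in one vectorized call
    print(d2f(0.5))        # scalar second derivative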

~~~
m0zg
From a practitioner's perspective:

1. Meh. PyTorch is close enough to not worry about it, and is better in some
places.

2. Meh. All the methods people use in practice for deep learning in particular
do not use higher-order gradients. Most higher-order methods are prohibitively
memory-expensive, and memory is at a premium in acceleration hardware (and so
is the bus bandwidth, so you can't "swap to RAM"). I do agree that
higher-order gradients are the next frontier in optimization, though; current
optimizer research seems to have stalled, so people focus on training with
huge batches and stuff like that. Most SOTA models in my field are trained
with SGD+momentum, super primitive stuff. I don't see how Jax would solve the
memory problem though. You still have to store those Hessians somewhere, at
least partially.

3. Do agree; that's cool if it actually parallelizes nontrivial stuff that
e.g. tf.vectorized_map barfs on. Although in a lot of cases you can
"vectorize" by concatenating input tensors into a higher-dimensional tensor.

4. Meh. Not sure why I'd want that if I already have tracing and JIT.

5. This is #2.

With PyTorch though, you get close enough to Numpy to feel at home in both,
and there's so much code written for it already that you can usually find a
good starting point for your research pretty easily on github and then build
on top of that.

If you need to deploy, there's also tracing and JIT, which let you load and
serve models with libtorch.
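
The tracing path looks roughly like this (toy module; from memory, so
double-check the details):

    # Trace a toy module and save it for serving from C++ via libtorch.
    import torch

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(128, 10)

        def forward(self, x):
            return self.fc(x)

    traced = torch.jit.trace(Net().eval(), torch.randn(1, 128))
    traced.save("model.pt")  # load in C++ with torch::jit::load("model.pt")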

I see what you're saying regarding "advantages"; I'm just pointing out that
PyTorch might be "good enough" for most people. If I were on that team, I'd
focus on providing a comfortable transition from TF 2.x, which is a dumpster
fire (with the exception of TensorFlow Lite, which is excellent). That, IMO,
would be the only way for this project to achieve mainstream success unless
PyTorch disintegrates over time.

~~~
chillee
I agree from a practitioner's standpoint, but you were talking about research
:)

1. Much of the scientific computing/stats community is stuck in the past. Many
are still using Matlab! Unlike the CS community, they aren't used to learning
new frameworks, so offering the ability to "import jax.numpy as np" and have
their scripts just run is valuable to them. As is having an API that they've
only just become familiar with (and that has far more documentation
available).

2. Once again, this is true for practitioners, but not research.
Hessian-vector products show up in a decent number of places. For example, if
you have an inner optimization loop (à la most meta-learning approaches, or
Deep Set Prediction Networks), you have a Hessian-vector product! (A sketch
follows after this list.) Perhaps not prevalent in the models practitioners
run, but definitely something to keep an eye on in research.

3. My understanding is that it actually does a pretty decent job. Enough that
it's useful in the prototyping phase.

4. PyTorch JIT is neat, and is what I meant by the PyTorch team "working" on
it. However, the JIT doesn't do full code gen (thus significant operator
overhead for, say, scalar networks) and has had significantly fewer man-hours
poured into it compared to XLA.

5. I was specifically talking about how you call grad on a function to get a
function that returns its gradient. It's a cleaner API than PyTorch's
autograd.
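
Here's the Hessian-vector-product sketch promised in point 2
(forward-over-reverse, so you never materialize the Hessian; the loss is a toy
stand-in):

    # Hessian-vector product via forward-over-reverse autodiff in Jax.
    import jax
    import jax.numpy as jnp

    def loss(w):
        return jnp.sum(jnp.tanh(w) ** 2)  # stand-in for an inner-loop loss

    def hvp(f, w, v):
        # jvp of grad(f) gives H(w) @ v directly, in O(cost of f) memory,
        # without ever forming the (n x n) Hessian.
        return jax.jvp(jax.grad(f), (w,), (v,))[1]

    w = jnp.ones(1000)
    v = jnp.full(1000, 0.1)
    print(hvp(loss, w, v).shape)  # (1000,)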

Jax is definitely not meant for deployment or industry usage, and I believe
its developers hope it'll never be pushed in that direction :^)

I definitely agree that PyTorch is "good enough" for most people. However,
among researchers, there are a decent number of subgroups it could gain favor
in.

You'd be surprised how many papers get submitted to ICML/NeurIPS that don't
use PyTorch or TensorFlow at all, in favor of raw numpy, C++, or even Matlab!
I think the numbers I had were that something like 30% of papers don't use any
ML framework. Jax could easily gain favor with this crowd.

There's also the crowd that cares _a lot_ about higher order gradients. Also,
admittedly a specific subgroup, but growing. Meta learning people care a lot.
So do Neural ODE people. All it takes is for one of these subfields to blow up
for higher order gradients to all of a sudden become a lot more appealing.

And finally, you have Google. Google researchers are never going to use
PyTorch en masse (probably). If researchers at Google want to switch from TF,
their only option is Jax. This is a pretty big subgroup of researchers :)

I definitely agree that Jax has a difficult hill to climb. But, they have a
solid foothold within Google, and several subfields very amenable to their
advantages.

PyTorch seems like the predominant research framework currently, but if any
framework is going to erode their lead, I'd place my bets on Jax.

~~~
m0zg
>> You'd be surprised how many papers get submitted to ICML/Neurips that don't
use PyTorch or TensorFlow at all

I do keep up with the literature and I do some applied research as well, so
yeah, I see such things from time to time. The volume of papers is so intense,
though, that unless there are other redeeming qualities, if a paper doesn't
use a framework I already know (TF or PyTorch) I ignore it entirely. I
wouldn't say I've missed much that could help me in practice. One exception is
Leslie Smith's work on cyclic learning rates and momentum modulation; he did
it on some ridiculous setup, but it works for what I do.

I'm more surprised how many papers are written for tiny little datasets that
you'd never use in practice, especially optimization papers. I mean, come on
guys, I get it, it's fast to train on CIFAR or Fashion-MNIST, but those
results rarely translate to anything practical. And some papers are just plain
not reproducible at all.

>> Google researchers are never going to use PyTorch en masse

IMO they should. It would easily double their productivity, and if Karpathy is
to be believed their skin and eyesight would improve too:
[https://twitter.com/karpathy/status/868178954032513024?lang=...](https://twitter.com/karpathy/status/868178954032513024?lang=en)

>> I'd place my bets on Jax

As an ex-Googler, I'd place my bets on something else TBH. Google projects
that aren't critical to Google's bottom line tend to deteriorate over time.
Just look at TF. I'm not cruel enough to suggest it to my clients anymore,
even though I could charge twice as much (because it would take twice as long
to get the same result).

~~~
chillee
Certainly if you ignore those papers, you'd likely have no issue in practice -
I suspect many of them are about more theoretical concerns. Perhaps I'll take
a look at/post a list tomorrow.

Either way, I believe that our original discussion was on why somebody should
bother. I provided a list of (admittedly) somewhat niche reasons. My personal
opinion is that Jax will stick around, and at the very least, provide some
neat ideas for Pytorch to ... independently come up with :)

>>> I'd place my bets on Jax

Hey hey hey, context! PyTorch is currently dominant in research, so who could
supplant it? Anecdotally, since I published my article
([https://thegradient.pub/state-of-ml-frameworks-2019-pytorch-dominates-research-tensorflow-dominates-industry/](https://thegradient.pub/state-of-ml-frameworks-2019-pytorch-dominates-research-tensorflow-dominates-industry/))
there has been more momentum towards PyTorch (Preferred Networks and OpenAI).

So if not TensorFlow, then who? I think PyTorch represents a local optimum and
is "good enough" for most people. So any newcomer framework needs to bring
something new to the table, even if it's niche. I think Jax looks the most
promising.

~~~
m0zg
I'd like to see something based on a proper, high-performance, statically
typed programming language, s.t. I could have a modicum of certainty that
things will work when someone changes something. With Python, sadly, you don't
know until you run things and hit error conditions dynamically. This is
unacceptable in larger codebases.

~~~
byt143
Even then, shape errors would require a dependent type system, which is not
found in most static languages.

~~~
m0zg
There are levels of survival I'm prepared to accept.

------
JPKab
Looking forward to a readme that is properly filled out. Some documentation as
well. Looks promising.

