Thank you for all the nice and constructive comments!
For clarity, this is ONLY the forward pass of the model. There's no training code, batching, kv cache for efficiency, GPU support, etc ...
The goal here was to provide a simple yet complete technical introduction to the GPT as an educational tool. Tried to make the first two sections something any programmer can understand, but yeah, beyond that you're gonna need to know some deep learning.
Btw, I tried to make the implementation as hackable as possible. For example, if you change the import from `import numpy as np` to `import jax.numpy as np`, the code becomes end-to-end differentiable:
    def lm_loss(params, inputs, n_head) -> float:
        # the targets are just the inputs shifted one position to the left
        x, y = inputs[:-1], inputs[1:]
        # forward pass: predicted next-token distributions, shape [n_seq, n_vocab]
        output = gpt(x, **params, n_head=n_head)
        # language modeling loss
        loss = np.mean(-np.log(output[y]))
        return loss

    grads = jax.grad(lm_loss)(params, inputs, n_head)
"hackable" and "simple yet complete technical introduction"
Music to my ears, well done and don't worry too much about the negative comments! They'll come out for anything you do I think.
I saw a tweet from someone the other day talking about how they massively increased their training speed by changing part of their architecture to have dimensions that were a factor of 64 rather than a prime-like kind of number.
One of the comments below it? ~"Seems very architecture specific."
lol.
So don't sweat it! <3 Great work and thanks for putting yourself out there, super job! :D :D :D :D :)))))) <3 :D :D :fireworks:
This is beautiful. Having worked with everything from nanoGPT to Megatron, sitting down and reading through picoGPT.py was clear and refreshing with just the essential details. Nothing left to add, nothing left to take away: perfection.
If you haven't tried cuNumeric [1], you really ought to. It's a drop-in NumPy wrapper for distributed GPU acceleration. Would be interesting to see if it works for this.
The problem with drop-in replacements between CPU and GPU code is that performant GPU code often requires rethinking the dataflow -- so even if the code itself is a drop-in, the "make it good" part still requires some rewriting.
I'd be curious how that library compares to other numeric python GPU libraries
I want to commend you for one of the best written introductions in this space that I've seen, especially the excellent use of hyperlinking that points to really good resources exactly at the right time!
This article is an absolutely fantastic introduction to GPT models - I think the clearest I've seen anywhere, at least for the first section that talks about generating text and sampling.
Then it got to the training section, which starts "We train a GPT like any other neural network, using gradient descent with respect to some loss function".
It's still good from that point on, but it's not as valuable as a beginner's introduction.
I think FastAI lesson 3 in "Practical Deep Learning for Coders" has one of the most intuitive buildups of gradient descent and loss that I've seen. Lecture [1], book chapter [2].
It doesn't go into the math but I don't think that's a bad thing for beginners.
If you want mathematical, 3blue1brown has a great series of videos [3] on the topic.
For those curious about writing a "gradient descent with respect to some loss function" starting from an empty .py file (and a numpy import, sure), can't recommend enough Harrison "sentdex" Kinsley's videos/book Neural Networks from Scratch in Python [1].
The beginning of Andrew Ng’s machine learning course on coursera does that too, it touches on the math a bit and explains how to imagine gradient descent in 3d space
Didn’t do the full course, but after the first few chapters I was able to write a very basic implementation in raw Python (emphasis here on “very basic”).
there is so much material on deep learning basics these days that I think we can finally skip reintroducing gradient descent in every tutorial, can't we?
The idea of "find in which direction function decreases most quickly and go that direction" is really deep, and its implementation via this cutting-edge mathematical concept of "gradient" also deserves a whole section as well.
On one hand, you can explain it to a 5-year-old: Go in the direction which improves things.
On the other hand, we have more than a half-century of research on sophisticated mathematical methods for doing it well.
The latter isn't really helpful for beginners, and the former is easy to explain. Beginners can't use the sophisticated algorithms in either case, so you can go with something as dumb as: tweak in every direction, and go where it improves most. It will work fine for toy examples.
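To make that "tweak and see" idea concrete, here is a toy sketch (purely illustrative, not from the post) that estimates the gradient by finite differences and steps downhill:

    import numpy as np

    def numerical_gradient_descent(f, x, lr=0.1, eps=1e-6, steps=200):
        # Dumb but effective for toy problems: bump each coordinate a little,
        # see how much f changes, and step in the direction of steepest decrease.
        x = np.asarray(x, dtype=float)
        for _ in range(steps):
            grad = np.zeros_like(x)
            for i in range(len(x)):
                bumped = x.copy()
                bumped[i] += eps
                grad[i] = (f(bumped) - f(x)) / eps
            x -= lr * grad
        return x

    # Example: minimize a simple bowl; the answer should land near [1, -2].
    print(numerical_gradient_descent(lambda v: (v[0] - 1) ** 2 + (v[1] + 2) ** 2, [5.0, 5.0]))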
This one doesn't use any frameworks. The next book by the author (on GANs) uses PyTorch. The math is relatively easy to follow I think.
Andrew Ng's courses on Coursera can be viewed for free and have slightly more rigorous math, but still okay.
You don't have to understand every mathematical detail, same as you don't need every mathematical detail for 3d graphics. But knowing the basics should be good I think!
That concept is not the easiest to describe succinctly inside a file like this (or -- while we are completely at it, in a Hacker News post like this!), I think (especially as there are various levels of 'beginner' to take into account here). This is considered a very entry level concept (not as an insult -- simply from an information categorization/tagging perspective here :D :)), and I think there might be others who would consider it to be noise if logged in the code or described in the comments/blogpost.
After all, there was a disclaimer that you might have missed up front in the blogpost! "This post assumes familiarity with Python, NumPy, and some basic experience training neural networks." So it is in there! But in all of the firehose of info we get maybe it is not that hard to miss.
However, I'm here to help! Thankfully the concept is not too terribly difficult, I believe.
Effectively, the loss function compresses the task we've described with our labels from our training dataset into our neural network. This includes (ideally, at least), 'all' the information the neural network needs to perform that task well, according to the data we have, at least. If you'd like to know more about the specifics of this, I'd refer you to the original Shannon-Weaver paper on information theory -- Weaver's introduction to the topic is in plain English and accessible to (I believe) nearly anyone off of the street with enough time and energy to think through and parse some of the concepts. Very good stuff! An initial read-through should take no more than half an hour to an hour or so, and should change the way you think about the world if you've not been introduced to the topic before. You can read a scan of the book at a university hosted link here: https://raley.english.ucsb.edu/wp-content/Engl800/Shannon-We...
Using some of the concepts of Shannon's theory, we can see that anything that minimizes an information-theoretic loss function should indeed learn as well those prerequisites to the task at hand (features that identify xyz, features that move information about xyz from place A to B in the neural network, etc). In this case, even though it appears we do not have labels -- we certainly do! We are training on predicting the _next words_ in a sequence, and so thus by consequence humans have already created a very, _very_ richly labeled dataset for free! In this way, getting the data is much easier and the bar to entry for high performance for a neural network is very low -- especially if we want to pivot and 'fine-tune' to other tasks. This is because...to learn the task of predicting the next word, we have to learn tons of other sub-tasks inside of the neural network which overlap with the tasks that we want to perform. And because of the nature of spoken/written language -- to truly perform incredibly well, sometimes we have to learn all of these alternative tasks well enough that little-to-no-finetuning on human-labeled data for this 'secondary' task (for example, question answering) is required! Very cool stuff.
This is a very rough introduction, I have not condensed it as much as it could be and certainly, some of the words are more than they should be. But it's an internet comment so this is probably the most I should put into it for now. I hope this helps set you forward a bit on your journey of neural network explanation! :D :D <3 <3 :)))))))))) :fireworks:
For reference, I'm interested very much in what I refer to as Kolmogorov-minimal explanations (Wikipedia 'Kolmogorov complexity' once you chew through some of that paper if you're interested! I am still very much a student of it, but it is a fun explanation). In fact (though this repo performs several functions), I made https://github.com/tysam-code/hlb-CIFAR10 as beginner-friendly as possible. One does have to make some decisions to keep verbosity down, and I assume a very basic understanding of what's happening in neural networks here too.
I have yet to find a good go-to explanation of neural networks as a conceptual intro (I started with Hinton -- love the man but extremely mathematically technical for foundation! D:). Karpathy might have a really good one, I think I saw a zero-to-hero course from him a little while back that seemed really good.
Andrej (practically) got me into deep learning via some of his earlier work, and I really love basically everything that I've seen the man put out. I skimmed the first video of his from this series and it seems pretty darn good, I trust his content. You should take a look! (Github and first video: https://github.com/karpathy/nn-zero-to-hero, https://youtu.be/VMj-3S1tku0)
For reference, he is the person that's made a lot of cool things recently, including his own minimal GPT (https://github.com/karpathy/minGPT), and the much smaller version of it (https://github.com/karpathy/nanoGPT). But of course, since we are in this blog post I would refer you to this 60 line numpy GPT first (A. to keep us on track, B. because I skimmed it and it seemed very helpful! I'd recommend taking a look at outside sources if you're feeling particularly voracious in expanding your knowledge here.)
I hope this helps give you a solid introduction to the basics of this concept, and/or for anyone else reading this, feel free to let me know if you have any technically (or-otherwise) appropriate questions here, many thanks and much love! <3 <3 <3 <3 :DDDDDDDD :)))))))) :)))) :))))
Here is an introduction to gradient descent with back propagation, for Ruby, based on Andrej Karpathy's micrograd: https://github.com/rickhull/backprop
So much criticism in the comments. I appreciated the write-up and the code samples. For some people not in ML like myself it's hard to understand the concept behind GPT and this made it a little bit clearer.
I think this is a factor of putting one's self out there. I've had this happen on ML projects I've put out too, though being hyper-engaged in trying to thoughtfully respond to all (or as many as possible of) the comments has seemed to lower negativity a bit for me, just because it brings the 'person-in-the-room' effect up to an online audience...at least, so I think! :D
I thought it was a great post and many kudos to the author for putting themselves out like that! I really appreciated this, and any work that puts this kind of effort into onboarding people and giving them tools to understand something well has, I think, some of the most long-term impact on the field.
Lowering barriers to entry, making resources accessible to all, and decreasing experimentation cycle time I think are some of the most critical components to making any progress at all in the field beyond a basic pittance. Imagine if everyone had easy access to, knowledge about, and rapid experimentation results in things like quantum mechanics, large-algorithm testing, painting arts, musical arts, etc. It would drive things so much further forward at an individual and field-based level so quickly. <3 :)))) :D :D ;D :D :D :))))))))
Karpathy has a bunch of great resources on this front! His minGPT writeup is excellent: https://github.com/karpathy/minGPT. His more recent project nanoGPT, which references this video, is a much more capable, but still learning-friendly, implementation.
I also learnt a ton from NLP Demystified - https://www.nlpdemystified.org. In fact, I used this resource before attempting Andrej Karpathy's https://karpathy.ai/zero-to-hero.html. I find Nitin's voice soothing and am able to focus more. I also found the pacing good, and the course introduces a lot of concepts at a beginner level and points to appropriate resources along the way (spaCy, for instance). Overall, an exciting time to be a total beginner looking to grok NLP concepts.
It turns out that transformers have a learning mechanism similar to autodiff but better since it happens mostly within the single layers as opposed to over the whole graph. I wrote a paper on this recently https://arxiv.org/abs/2302.01834v1. The math is crazy.
Can you explain like I'm 5 why this matters distinctly from how transformers are normally trained with autodiff and what its possible applications are?
The paper speculates that it is analogous to gradient descent and empirically confirms it is similar in behavior, but it is not a rigorous proof of any kind.
The momentum experiment they made also does not seem related. E.g. it just adds past values to V, which extends the effective context length.
Match is a bad word; they don't match, they are duals. The residual stream (aka identity mapping) needs to be the identity of the attention mechanism as the attention mechanism learns.
But this is the same for all residual streams, not just those in transformers.
Gradient descent is just how neural networks (including auto-encoders) optimize parameters to minimize the loss function. They do this using derivatives to descend down the slope of the function. Autodiff is one way to compute the derivatives. Maybe we’re saying the same thing.
One reason is that some ML libraries are really slow to import, so you don't want to put them at top-level unless you definitely need them. E.g. if I had just one function that needed to use a tokenizer from the Transformers library, I wouldn't want to eat a 2 second startup cost every time:
    In [1]: %time import transformers
    CPU times: user 3.21 s, sys: 7.8 s, total: 11 s
    Wall time: 1.91 s
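In a case like that you'd typically defer the import into the one function that needs it -- roughly like this sketch (the tokenizer class here is just an illustrative pick from the Transformers library):

    def tokenize(text):
        # Deferred import: the heavy dependency is only loaded the first time
        # this code path actually runs, not at program startup.
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        return tokenizer.encode(text)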
I didn't think about lazy loading, I also didn't know they were scoped differently! I thought it was some sort of organisation to keep imports close to usage. Thanks!
The scoping also has some performance advantage: locals are accessed by index in the bytecode, with all name resolution happening at compile-time, but globals require a string lookup in the module dictionary every time they're accessed.
This isn't something that should matter even a little in typical ML code. But in generic Python libraries, there are cases when this kind of micro-optimization can help. Similar tricks include turning methods into pre-bound attributes in __init__ to skip all the descriptor machinery on every call.
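For illustration, those two tricks look roughly like this (a contrived sketch, not from any particular library):

    import math

    class HotPath:
        def __init__(self):
            # Pre-bind the bound method as a plain instance attribute, so each
            # call skips the descriptor machinery that normally builds a new
            # bound-method object on every lookup.
            self.compute = self._compute

        def _compute(self, v):
            return math.sqrt(v * v)

        def run(self, values):
            compute = self.compute  # local alias: resolved by index in the bytecode
            return [compute(v) for v in values]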
Curious, in what cases might this help? The compute would have to be python-bound (not C library-bound); and the frequency of module lookups would have to be in the ballpark of other dictionary lookups. I wonder if cases like this exist in the real world.
The case where I've seen those tricks used with measurable effect was a Python debugger - specifically, its implementation of the sys.settrace callback. Since that gets executed on every line of code, every little bit helps.
(It's much faster if you implement the callback in native code, but then that doesn't work on IronPython, Jython etc.)
Author here. It's a design choice, but there's two reasons I chose to use imports like this:
1) For demonstrative purposes. The title of the post is `A GPT in 60 Lines of NumPy`, I kinda wanted to show "hey it's just numpy, nothing to be scared about!". Also if an import is ONLY used in a single function, I find it visually helps show that "hey, this import is only used in this function" vs when it's at the top of the file you're not really sure when/where and how many times an import is used.
2) Scoping. `load_encoder_hparams_and_params` imports tensorflow, which is really slow to import. When I was testing, I used randomly initialized weights instead of loading the checkpoint which is slower, so I was only making use of the `gpt2` function. If I kept the import at the top level, it would've slowed things down unnecessarily.
I do sometimes - it just depends on the context and how often the function (or library) is going to get called.
Here - they put `import fire` only in the `if __name__ == "__main__":` - that seems reasonable to me as anyone pulling in the library from elsewhere doesn't need the pollution.
Right, I do this with argparse for creating simple CLIs for a module generally intended to be imported and used in another program. argparse has nothing to do with the actual module functions and won't be needed if the module is going to be used in a web app or some other context.
This makes even more sense for a non-standard-library dependency like fire, because you won't even need it if you're going to import the module and write your own interface instead.
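As a rough sketch of that pattern (the `main` signature here is just a placeholder, not the repo's actual entry point):

    def main(prompt: str, n_tokens_to_generate: int = 40):
        ...  # the actual work of the module

    if __name__ == "__main__":
        # Only people running this file as a CLI pay for the fire import;
        # importing the module from elsewhere never touches it.
        import fire
        fire.Fire(main)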
The import in main doesn't seem particularly useful in context on a quick read, but considering the line
> utils.py contains the code to download and load the GPT-2 model weights, tokenizer, and hyper-parameters.
it seems possible some downloads are happening on import so does make sense to defer until actually needed, as suggested in sibling comments.
Does that import have side effects? Are we really worried about adding an entry to the imports dict if not? Or put differently, what cases do we actually get a negative effect from just importing at the top?
Oh yeah, imports in Python are not just, like, extending a namespace as in many other languages. At runtime they go and execute the module's top-level code (or a package's __init__.py) and can have arbitrary side effects - an entire program can run (although it usually shouldn't) just in the import. Imports of large modules often take entire seconds.
It is absolutely worthwhile to avoid unnecessary imports if possible.
I know they _can_ have side-effects, I’ve just never seen a case where it actually mattered, and I have used Python professionally for 10 years. So I’m curious if this is more common in ML libraries or something.
I guess it depends on your definition of "side-effects" but it definitely comes up in common ML packages. For one example, importing `torch` often takes seconds, or tens of seconds, because the import itself needs to determine if it should come up in CUDA mode or not, how many OpenCL/GPU devices you have, how they're configured, etc.
It wouldn't surprise me if the original reason is the pervasive use of jupyter notebooks in ML, which don't adhere to normal python conventions, and are affected by slow imports only when those sections are explicitly evaluated.
Side-effects in imports are, in my opinion, unnecessary: they lose some of the benefits of static analysis, make it harder to run with different parameters during tests or to compile to native code (where those tools exist), slow things down, and more.
Libraries could have an initializer function and the problem would go away.
Importing another module takes non-zero time and uses non-zero memory, and let's face it: python is not exactly a fast language. Personally I'd appreciate a library author that takes steps to avoid a module load when that module is only used (for example) in some uncommonly-taken code paths.
In some (many?) cases it's probably premature optimization, but it doesn't hurt, so I don't see why anyone would get up in arms over it.
importing is a _runtime_ operation: unless previously imported, the interpreter will go and import that module, executing that module's code. that can take a while. it will also bind a name in the current scope to the module object, so... that might be considered pollution?
I'm in ML and I would also like an answer to this question.
I've seen a lot of Python people sprinkle imports all over the place in their code. I suspect this is a bad habit learned from too much time working in notebooks where you often have an "oh right, I need XXX library now" and just import it as you need it.
The aggressive aliasing I do get since in DS/ML work it's very common to have the same function do slightly different things depending on the library (standard deviation between numpy and pandas is a good example)
But I personally like all of my imports at the top so I know what the code I'm about to read is going to be doing. I do seem to be in the minority in this (and would be glad to be corrected if I'm making some major error).
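To make that numpy/pandas standard-deviation example concrete, the defaults really do differ:

    import numpy as np
    import pandas as pd

    x = [1.0, 2.0, 3.0, 4.0]
    print(np.std(x))           # ~1.118, population std (ddof=0 by default)
    print(pd.Series(x).std())  # ~1.291, sample std (ddof=1 by default)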
I often end up having to inline imports, because python doesn't support circular imports.
Of course, "don't do circular imports". But if my Orders model has OrderLines, and my OrderLines points to their Order, it's damn hard to avoid without putting everything in one huge file..
Ha, if only! I've been the one to introduce this at the last three jobs I've had, two of which had hundreds of engineers and plenty of python code before I got there.
"Best practices" are incredibly unevenly distributed, and I suspect this is only more true for data/ML-heavy python code.
New (v5) isort doesn't move imports to the top of the file anymore, at least not by default. There is a flag to retain the old behavior, but even then I don't think it will move imports from, say, inside a function body to the top of the module.
Scope-dependent imports. What if a package is just required for that particular function, and once that function is done, the imported package is no longer required?
Another reason (besides the ones already mentioned in the other comments) is that some imports might only be available on certain operating systems or architectures. I once wrote heavily optimized ML code for Nvidia Jetson Nano devices but I still wanted to be able to test the overall application (the non-Nvidia-specific code) on my laptop or in pipelines.
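In practice that usually ends up as a guarded import, something like this sketch (tensorrt is just an example of an Nvidia-only dependency here):

    def make_backend():
        try:
            # Only present on the Jetson image; absent on the dev laptop and in CI.
            import tensorrt
            return "tensorrt", tensorrt
        except ImportError:
            return "cpu", None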
Why not: it's possible to get an ImportError if you make a mistake in the import statement. This kind of error should happen as early as possible, and you won't expect it to happen during a random function call.
If you're writing in a dynamically-typed, interpreted language like python, I think mistyping an import inside a function is really the least of your concerns when it comes to mistyping things.
Lazy loading, avoiding pollution of symbols in the root scope, avoiding re-exports of symbols in the root scope, self-documenting code ("this function uses these libraries"), portable coding (sometimes desirable), etc.
1) Circular dependencies (and you don't want your house of cards falling down if your IDE/isort decides to reorder a few things); 2) (slow/expensive) expressions that are evaluated on import; 3) startup time required for the module loader to resolve everything at start.
Reminds me of the scene from Westworld where they explain their failed prototypes of the human mind with millions of lines of code. The version that finally worked was only a few dozen.
How powerful/heavy is it? Some time ago there was a post here about implementing a GPT on a very constrained computer (under a gigabyte of RAM, some old CPU, no GPU (?)), as opposed to an ordinary kind of GPT requiring terabytes of RAM.
I immediately thought it would be nice to do something in the middle: taking full advantage of a reasonably modern multicore CPU with AVX support, a humble yet reasonably modern OpenCL-capable GPU, and some 32 gigabytes of RAM.
Number of params is the number of weights. Basically the number of learnable variables.
Number of tokens is how many tokens it saw during training.
Vocab size is the number of distinct tokens.
The relationship between params/tokens/compute and how it affects model performance is something people have studied a good deal. https://arxiv.org/pdf/2203.15556.pdf
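As a rough illustration in the spirit of the post, you could count the parameters of the NumPy GPT by walking its params structure -- a sketch, assuming the nested dict/list-of-arrays layout the loader returns:

    import numpy as np

    def count_params(params) -> int:
        # Recursively sum the number of elements in every weight array.
        if isinstance(params, np.ndarray):
            return params.size
        if isinstance(params, dict):
            return sum(count_params(v) for v in params.values())
        if isinstance(params, (list, tuple)):
            return sum(count_params(v) for v in params)
        return 0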
If I maintain an open source project, could I build a doc page using a small GPT allowing users to query FAQ and common methods using natural language?
You can build anything using numpy. You can build a supercomputer out of duct tape if you want to.
Spinning up a db, serving an API, doing natural language processing, ... whatever you want.
That said, there are niche solutions that do these things well and that can save you a lot of work. There are also frameworks (such as Pyntango/Jupyter/Nbspectrum) that let you spin up a simple API quickly.
I think the completeness and self-contained-ness more than offsets the limited scope. One of the problems in the ML field is rapidly multiplying logistical complexity, and I appreciate an example that is (somewhat) functional but simple enough to fit on a postcard and using very basic components.
For someone not familiar with JAX: if I do the suggested replacement, what'd be the little extra code to make it do the backward pass?
Or is it all automatic and we literally would not need extra lines of code?
Backprop is just an implementation detail when doing automatic differentiation, basically setting up how you would apply the chain rule to your problem.
JAX is able to differentiate arbitrary python code (so long as it uses JAX for the numeric stuff) automatically so the backprop is abstracted away.
If you have the forward model written, to train it all you have to do is wrap it in whatever loss function you want, then use JAX's `grad` with respect to the model parameters, and you can use that to find the optimum using your favorite gradient optimization algorithm.
This is why JAX is so awesome. Differentiable programming means you only have to think about problems in terms of the forward pass and then you can trivially get the derivative of that function without having to worry about the implementation details.
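Concretely, a single training step might look something like this minimal sketch, assuming the `lm_loss` from the author's comment above and a pytree of weights called `params`:

    import jax

    def sgd_step(params, inputs, n_head, lr=1e-4):
        # jax.grad gives d(loss)/d(params) directly from the forward-pass code.
        grads = jax.grad(lm_loss)(params, inputs, n_head)
        # Plain gradient descent applied to every array in the params pytree.
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)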
I haven't heard about JAX before, but I've been tinkering in PyTorch. Would I also be able to switch the np arrays here to torch tensors, then call .backward() and get kinda the same benefits as JAX, or how does it differ in this regard?
"Of course, you need a sufficiently large model to be able to learn from all this data, which is why GPT-3 is 175 billion parameters and probably cost between $1m-10m in compute cost to train.[2]"
So, perhaps better title would be "GPT in 60 Lines of Numpy (and $1m-$10m)"
And it will be even more expensive to train it again on larger amounts of data and with a model with 10 times more parameters.
Only Big Tech giants like Microsoft, Google, etc. can afford to foot the bill and throw millions into training LLMs, whilst we celebrate and hype ChatGPT and LLMs getting bigger and significantly more expensive to train, even as they get confused, hallucinate over silly inputs, and confidently generate bullshit.
That can't be a good thing. OpenAI's ClosedAI model needs to be disrupted like how Stable Diffusion challenged DALLE-2 with an open source AI model.
I disagree. I run a small tech company with a group that's been experimenting with Stable Diffusion, and we noticed that an extreme version of the Pareto Principle applies here as well: you can get ~90% of the benefits for like 5% of the cost, combined with the fact that computing power is continuously getting cheaper.
Based on that group's success, they've recently proposed a mini project inspired by GPT that I am considering funding; the data it's trained on is all publicly available for free, and most of it comes from Common Crawl. I suspect that it will also yield similar results, where you can tailor your own version of GPT and get reasonably good models for a fraction of the price as well. We're nowhere near the scale of the Big Tech giants, but I've noticed for the better part of 15 years that small companies can actually derive a great deal of the benefits that larger companies have, for a fraction of the cost, if they play it smart and keep things tight.
This is happening already. The trick is to run a search against an existing search engine, then copy and paste the search results into the language model and ask it to answer questions based on what you provide it.
A small difference between the pattern you describe and the one of the inquiry is where responsibility lies for retrieving and incorporating the augmentation. You describe the pattern where an orchestration layer sits in front of the model, performs the retrieval, and then determines how to serve that information down to the model. The inquiry asks about whether the AI/model itself can perform the retrieval and incorporation function.
It’s a small difference, perhaps, but with some significance since the retrieval and incorporation occurring outside the model has a different set of trade offs. I’m not specifically aware of any work where model architectures are being extended to perform this function directly, but I am keen to learn of such efforts.
Yes, check out LangChain [0]. It enables you to wire together LLMs with other knowledge sources or even other LLMs. For example, you can use it to hook GPT-3 up to WolframAlpha. I’m sure you could pretty easily add a way for it to communicate with a human expert, too.
If an expert writes a long text and you add "in summary: " at the end, the model will complete it with something approximating the truth (depending on the size of the model, training, etc.).
Humans do a similar thing. We have a model in our head of the subject being discussed and we can summarize it, but we will forget some parts, make errors, etc. GPT is very similar.
It is! You can specify in its prompt that it should "request additional info via search query, using the following syntax: [[search terms here]], before coming to a final conclusion", then you integrate it with a traditional knowledge-base text lookup and run it again with that information concatenated.
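The orchestration loop for that pattern is only a few lines. A rough sketch, where `complete` and `search` stand in for your LLM call and your knowledge-base lookup (both hypothetical placeholders):

    import re

    def answer(question, complete, search, max_rounds=3):
        prompt = (
            "Request additional info via search query, using the following syntax: "
            "[[search terms here]], before coming to a final conclusion.\n\n"
            f"Question: {question}\n"
        )
        for _ in range(max_rounds):
            reply = complete(prompt)
            queries = re.findall(r"\[\[(.+?)\]\]", reply)
            if not queries:
                return reply  # the model answered without asking for more info
            # Paste the retrieved text back in and let the model try again.
            for q in queries:
                prompt += f"\nSearch results for '{q}':\n{search(q)}\n"
        return complete(prompt)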
Stable Diffusion could do it because the task turned out to be amenable to reasonably small models. But there's no evidence of that being the case with GPT.
That said, other organizations that can afford to foot the bill for it are the governments. This is hardly ideal, since such models will also come with plenty of strings attached - indeed, probably more than the private ones - but at least these policies are somewhat checked by democratic mechanisms.
Long-term I think the demand for more AI compute power will lead to much more investment in GPU design and manufacture, driving the prices down. Since the underlying tech itself is well-understood, I fully expect to see the day when one can train and run a customized GPT-3 instance for one's private use, although the major players will likely be far ahead by then.
I saw this [1] presentation where they use Scheme to train GPT on a single consumer GPU. I've had no luck finding the 'scorch' compiler they mentioned in the video.
Perhaps I’m missing your point, but isn’t that what they do with their API right now? You pay for text completions, and can fine-tune their model with your data.
Of course, if they leaked the model weights and a local inference binary for it, they would lose the ability to charge for it. Clones with the weights would crop up all over the place.
Anger… fear… aggression. The dark side are they. Easily they flow, quick to join you in a fight. If once you start down the dark path, forever will it dominate your destiny, consume you it will, as it did Obi-Wan’s apprentice. -- Master Yoda, The Empire Strikes Back
Also, for the curious mind, here is a more authentic tutorial on building GPT (the precursor to GPT-3) by Andrej Karpathy: https://www.youtube.com/watch?v=kCc8FmEb1nY
My point is: if you want to spend time, spend it on authentic material, not some bogus material.
Since most models require little code compared to big software projects, why not use C++ or any other compiled language directly?
Python with its magic functions and shortcuts is just hiding too much complexity, which can result in bugs and performance issues. Plus the code is harder to maintain.
> Python with its magic functions and shortcuts is just hiding too much complexity
One counterpoint would be that verbosity, especially in the heavy syntax style of languages such as C++, distracts the reader and helps bugs hide in plain sight. For a silly example, imagine trying to read and verify the correctness of an academic paper from its uncompiled LaTeX source.
A lot of AI people (not a huge amount, but more than you'd think) can't code in any sense that would get them a job at a normal software company; Python is easy, and fast enough to last until the model is obsolete.
You can even support batching with `jax.vmap` (https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.h...). Of course, with JAX comes built-in GPU and even TPU support! As for training code and a KV cache for inference efficiency, I leave those as an exercise for the reader lol
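Something along these lines, presumably -- a sketch that wraps the forward pass so `vmap` only maps over the batch axis of the token ids, while the weights and `n_head` are shared:

    import jax

    # [batch, n_seq] -> [batch, n_seq, n_vocab]
    gpt_batched = jax.vmap(lambda ids: gpt(ids, **params, n_head=n_head))
    # logits = gpt_batched(batched_inputs)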