Feedback on NanoDL: A library for building custom transformers from scratch (github.com/hmunachi)
1 point by HMUNACHI 15 days ago | 1 comment



Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. JAX, a lightweight yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in JAX are limited. NanoDL addresses this gap with the following features:

- A wide array of blocks and layers for building customised transformer models from scratch.
- An extensive selection of models, including Gemma, LLaMA, Mistral, Mixtral, GPT-3, GPT-4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, covering a variety of tasks and applications.
- Data-parallel distributed trainers, including RLHF, so developers can efficiently train large-scale models on multiple GPUs or TPUs without writing manual training loops.
- Dataloaders that make data handling in JAX/Flax more straightforward.
- Custom layers not found in Flax/JAX, such as RoPE, GQA, MQA, and Swin attention, for more flexible model development.
- GPU/TPU-accelerated classical ML models such as PCA, KMeans, regression, and Gaussian processes, akin to scikit-learn on GPU.
- A modular design, so users can blend elements from different models, such as GPT, Mixtral, and LLaMA 2, into hybrid transformer models.
- True random number generators in JAX that do not need the usual verbose key-handling code.
- A range of algorithms for NLP and computer vision tasks, such as Gaussian blur, BLEU, and tokenizers.
- Each model lives in a single file with no external dependencies, so its source code can easily be reused on its own.
- Experimental features (such as the Mamba architecture and RLHF) that are in the repo but not yet available via the package, pending tests.
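RoPE is the most self-contained of the custom layers listed above. As a rough illustration of what such a layer computes, here is a minimal sketch in plain JAX. It is not NanoDL's API: the function names `rope_frequencies`, `rotate_half`, and `apply_rope`, and the `(batch, seq_len, heads, head_dim)` layout, are assumptions made for this example. Rotary embeddings rotate each pair of query/key channels by an angle that grows with the token's position.

```python
import jax
import jax.numpy as jnp


def rope_frequencies(seq_len: int, head_dim: int, base: float = 10000.0):
    """Return cos/sin tables of shape (seq_len, head_dim) for rotary embeddings."""
    assert head_dim % 2 == 0, "head_dim must be even so channels can be paired"
    # One inverse frequency per channel pair: base^(-2i/d) for i in [0, d/2).
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len)
    angles = jnp.outer(positions, inv_freq)               # (seq_len, head_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)   # (seq_len, head_dim)
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, cos, sin):
    """Rotate x of shape (batch, seq_len, heads, head_dim) by position-dependent angles."""
    # Broadcast the (seq_len, head_dim) tables over the batch and head axes.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    return x * cos + rotate_half(x) * sin


if __name__ == "__main__":
    q = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 4, 64))
    cos, sin = rope_frequencies(seq_len=16, head_dim=64)
    print(apply_rope(q, cos, sin).shape)  # (2, 16, 4, 64)
```

Because each channel pair is rotated rather than shifted, the dot product between a rotated query at position m and a rotated key at position n depends only on the offset m - n, which is what makes RoPE a relative position encoding. This sketch pairs channel i with channel i + head_dim/2 ("rotate half"); other implementations interleave adjacent channels instead, and the two conventions are not interchangeable with the same pretrained weights.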

I'd appreciate feedback if you have the time. A dev pre-release is available via pip and is best suited to building models with no more than 1B parameters. There are still lots of improvements to make. If you're feeling generous, you can contribute or leave a star.
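If you are unsure whether a model stays under that ~1B-parameter guideline, one rough check in plain JAX is to count the scalars in its parameter pytree; for a Flax module, that is the `'params'` collection returned by `model.init`. Both helper names below are hypothetical, and `estimate_decoder_params` assumes a GPT-style decoder with a 4x-wide MLP and a tied embedding, ignoring biases and layer norms, so it will not exactly match LLaMA-style SwiGLU blocks.

```python
import jax
import jax.numpy as jnp


def count_params(params) -> int:
    """Count every scalar in a pytree of arrays, e.g. a Flax 'params' dict."""
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))


def estimate_decoder_params(n_layers: int, d_model: int, vocab_size: int) -> int:
    """Rough size of a GPT-style decoder (hypothetical helper).

    Assumes per layer: 4 * d^2 for attention (Q, K, V, output projections)
    plus 8 * d^2 for a 4x-wide MLP, and one tied vocab_size x d embedding.
    Biases and layer norms are ignored.
    """
    return 12 * n_layers * d_model**2 + vocab_size * d_model


if __name__ == "__main__":
    toy = {
        "dense": {"kernel": jnp.zeros((512, 2048)), "bias": jnp.zeros((2048,))},
        "embed": jnp.zeros((32000, 512)),
    }
    print(count_params(toy))                          # 17434624
    print(estimate_decoder_params(16, 2048, 32000))   # 870842368, under 1B
    print(estimate_decoder_params(24, 2048, 32000))   # 1273495552, over 1B
```

Counting the real parameter tree is exact for whatever you built; the closed-form estimate is only useful for sizing a configuration before you instantiate it.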

Thanks.



