Yep! The goal was to implement Gemma in Unsloth to make finetuning faster and use less VRAM, and my reimplementation seems to get different results than the current ones.
Yeah, it was indeed very gruelling - but very fun!! I used torch.dist everywhere, read all the implementations side by side to compare them, and had to manually inspect losses, plot them, etc. It's a bit hard to automate sadly, since new archs cause new issues.
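Roughly, the side-by-side comparison looks like the sketch below. This is a minimal, dependency-free illustration and not the actual Unsloth code: `first_divergence`, the toy layers, the deliberately planted epsilon bug, and the tolerance are all made up for the example. `math.dist` stands in for `torch.dist(a, b)`, which computes the same L2 distance on tensors.

```python
import math

def first_divergence(ref_layers, new_layers, x, tol=1e-9):
    """Run two layer stacks on the same input and return (layer_index, distance)
    for the first layer whose outputs differ by more than tol, else None."""
    ref_out, new_out = list(x), list(x)
    for i, (ref_layer, new_layer) in enumerate(zip(ref_layers, new_layers)):
        ref_out = ref_layer(ref_out)
        new_out = new_layer(new_out)
        # Euclidean distance between the two outputs
        # (with real tensors this would be torch.dist(ref_out, new_out)).
        d = math.dist(ref_out, new_out)
        if d > tol:
            return i, d
    return None

# Hypothetical toy layers, invented for illustration only.
def double(v):
    return [2.0 * a for a in v]

def rms_norm_ref(v, eps=1e-6):
    # Reference: eps is added inside the square root.
    rms = math.sqrt(sum(a * a for a in v) / len(v) + eps)
    return [a / rms for a in v]

def rms_norm_new(v, eps=1e-6):
    # Planted bug: eps is added outside the square root.
    rms = math.sqrt(sum(a * a for a in v) / len(v)) + eps
    return [a / rms for a in v]

def add_one(v):
    return [a + 1.0 for a in v]

ref_layers = [double, rms_norm_ref, add_one]
new_layers = [double, rms_norm_new, add_one]

hit = first_divergence(ref_layers, new_layers, [1.0, 2.0, 3.0])
if hit is None:
    print("outputs match within tolerance")
else:
    print(f"first mismatch at layer {hit[0]}, L2 distance {hit[1]:.3e}")
```

Checking after every layer, rather than only at the final output, is what points you at the layer to go read side by side. The tolerance is the fiddly part: in float16 or bfloat16 you'd need a much looser one, which is part of why this is hard to automate.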
What's your approach for these more subtle numerical bugs?