Thanks, that's certainly helpful anecdotal evidence. Yeah, it seems like there should be an "AllNorm" implementation that covers all cases and is just fast. I was wondering because I'm currently looking at math_group_norm, which was ported from PyTorch/XLA and results in a really weird decomposition that I'm astonished works at all: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen...
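To illustrate the kind of decomposition I mean (this is just my reading of the idea, not the actual aten code), GN can be phrased as BN over a reshaped tensor: fold each (sample, group) pair into the channel dimension, and batch_norm's per-channel statistics become exactly GN's per-group statistics. Rough sketch:

    import torch
    import torch.nn.functional as F

    def group_norm_via_batch_norm(x, num_groups, weight, bias, eps=1e-5):
        # Sketch only: fold (sample, group) into the channel dim so that
        # batch_norm's per-channel statistics are GN's per-group statistics.
        N, C, H, W = x.shape
        x_r = x.reshape(1, N * num_groups, -1)
        y = F.batch_norm(x_r, None, None, training=True, eps=eps)
        y = y.reshape(N, C, H, W)
        # Per-channel affine, as in GN.
        return y * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)

    # Sanity check against the reference implementation:
    x = torch.randn(2, 6, 4, 4)
    w, b = torch.ones(6), torch.zeros(6)
    torch.testing.assert_close(
        group_norm_via_batch_norm(x, 3, w, b),
        F.group_norm(x, 3, w, b))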
I'm also wondering whether the handcoded backward passes are actually "numerically correct", because e.g. epsilon doesn't appear in them at all. Someone worked out the gradients manually for BN here: https://web.archive.org/web/20180826123459/http://cthorey.gi...
You can clearly see epsilon appearing in the resulting expressions. And of course there's the whole training vs. eval mode thing with BN, which GN doesn't have.
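For reference, the collapsed form of that BN backward (following the derivation in the link, not PyTorch's actual kernel) looks roughly like this; epsilon survives into the gradient via the inverse standard deviation:

    import torch

    def batch_norm_backward(dout, x, gamma, eps=1e-5):
        # Manual backward for y = gamma * (x - mu) / sqrt(var + eps) + beta,
        # per feature, over a batch of size N. Sketch of the standard derivation.
        N = x.shape[0]
        mu = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)
        inv_std = 1.0 / torch.sqrt(var + eps)   # <-- this is where eps lives
        x_hat = (x - mu) * inv_std

        dbeta = dout.sum(dim=0)
        dgamma = (dout * x_hat).sum(dim=0)
        # Collapsed dL/dx after substituting dvar and dmu back in:
        dx = (gamma * inv_std / N) * (N * dout - dbeta - x_hat * dgamma)
        return dx, dgamma, dbeta

    # Quick check against autograd:
    x = torch.randn(8, 4, requires_grad=True)
    gamma, beta = torch.ones(4), torch.zeros(4)
    y = gamma * (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + 1e-5) + beta
    dout = torch.randn_like(y)
    y.backward(dout)
    dx, _, _ = batch_norm_backward(dout, x.detach(), gamma)
    torch.testing.assert_close(dx, x.grad)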
In any case, thanks again.