I'm getting to a point where I'm implementing a lot of different DL models for my research (in PyTorch) and it's obvious to me there's a lot of boilerplate (creating the data loaders, forward pass, backward pass, val (.eval) pass, collecting statistics, saving checkpoints, etc.). So like a good software engineer I try to factor these "pipelines" into reusable components (a train step, a val step, etc.). The problem is that every time I feel like I have a good factorization, I find some quirky model or training-trick edge case (e.g. needing to do something with the optimizer at some point in the training process).
Any advice on either how to factor this or some framework that already does it for me? I've looked at fastai's callback model and initially tried to roll my own version, but it ends up being so brittle, with very leaky function interfaces (passing boolean flags, branching on reflection, etc.).
My kingdom for a monad!
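To be concrete, the shape I keep converging on is something like this. Purely a hypothetical sketch with made-up names, and it glosses over device placement, metrics, and checkpointing; the idea is that every hook receives the trainer object, so edge cases can touch the optimizer/scheduler without threading flags through the loop:

    import torch

    class Callback:
        # Each hook gets the trainer, so quirky models can poke at the
        # optimizer, scheduler, model, etc. without extra flags.
        def on_epoch_start(self, trainer): pass
        def on_batch_end(self, trainer, loss): pass
        def on_epoch_end(self, trainer, val_loss): pass

    class Trainer:
        def __init__(self, model, optimizer, loss_fn, callbacks=()):
            self.model, self.optimizer, self.loss_fn = model, optimizer, loss_fn
            self.callbacks = list(callbacks)

        def _emit(self, hook, *args):
            for cb in self.callbacks:
                getattr(cb, hook)(self, *args)

        def fit(self, train_loader, val_loader, epochs):
            for self.epoch in range(epochs):
                self._emit("on_epoch_start")
                # train step
                self.model.train()
                for x, y in train_loader:
                    loss = self.loss_fn(self.model(x), y)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self._emit("on_batch_end", loss.detach())
                # val step
                self.model.eval()
                with torch.no_grad():
                    val_loss = sum(self.loss_fn(self.model(x), y)
                                   for x, y in val_loader) / len(val_loader)
                self._emit("on_epoch_end", val_loss)

    # The edge cases then become callbacks instead of branches in the loop, e.g.
    # a made-up "unfreeze some params partway through training" trick:
    class UnfreezeAtEpoch(Callback):
        def __init__(self, epoch, params, lr):
            self.epoch, self.params, self.lr = epoch, params, lr
        def on_epoch_start(self, trainer):
            if trainer.epoch == self.epoch:
                trainer.optimizer.add_param_group({"params": self.params, "lr": self.lr})

It still leaks (you end up passing the whole trainer everywhere), which is exactly the brittleness I'm complaining about.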
I think PyTorch Lightning might be worth taking a look at: https://github.com/williamFalcon/pytorch-lightning
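Roughly, you implement only the model-specific pieces and the Trainer owns the loop (devices, checkpointing, logging). A minimal sketch of the interface as I understand it; the exact method names and Trainer/fit signatures have shifted across versions, so check the docs rather than trusting this:

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl

    class LitClassifier(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(28 * 28, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 10),
            )

        def forward(self, x):
            return self.net(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            # You write the forward/loss; Lightning handles backward/step/zero_grad.
            x, y = batch
            return F.cross_entropy(self(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            self.log("val_loss", F.cross_entropy(self(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    # trainer = pl.Trainer(max_epochs=10)
    # trainer.fit(LitClassifier(), train_loader, val_loader)

Training tricks that don't fit the standard loop go into overridable hooks or Callback objects, which is basically the fastai idea but with the boilerplate kept inside the library.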