Beyond the other answers, I'll point out that PyTorch is developing tools that should make this much easier, whether you do the work by hand or implement it in a framework. They're building a native DTensor (distributed tensor) implementation and experimenting with SPMD-style distributed models with pipelining. DTensor lives in torch.distributed, and the SPMD code is in the repo called Tau under the pytorch org on GitHub.
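
To give a rough idea of what "SPMD-style" with a sharded tensor means, here is a minimal conceptual sketch in plain Python. It does not use the DTensor API at all: the function names (`shard_rows`, `local_matmul`, `all_gather_rows`) are hypothetical, lists stand in for tensors, and a loop stands in for the ranks. The point is only that every rank runs the same program on its own shard while the weight is replicated.

```python
# Conceptual sketch only: plain lists stand in for tensors and devices.
# This is NOT the torch.distributed DTensor API; all names are hypothetical.

def shard_rows(matrix, world_size):
    """Split a matrix into contiguous row blocks, one block per rank."""
    rows_per_rank = -(-len(matrix) // world_size)  # ceiling division
    return [matrix[r * rows_per_rank:(r + 1) * rows_per_rank]
            for r in range(world_size)]

def local_matmul(local_rows, weight):
    """The single program that every rank runs on its own shard."""
    cols = list(zip(*weight))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols]
            for row in local_rows]

def all_gather_rows(shards):
    """Stand-in for a collective that reassembles the full result."""
    return [row for shard in shards for row in shard]

x = [[1, 2], [3, 4], [5, 6], [7, 8]]   # input, sharded by rows
w = [[1, 1], [0, 1]]                   # weight, replicated on every rank

shards = shard_rows(x, world_size=2)
outputs = [local_matmul(s, w) for s in shards]  # one iteration per "rank"
result = all_gather_rows(outputs)
```

As I understand it, DTensor's job is to track the sharded/replicated placement for you and insert the communication, so you don't write this bookkeeping by hand.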