Autodiff is one area where there is a real gulf between Julia and Python, and I don't think it can be bridged easily: JuliaDiff is astonishingly flexible and performant.
I linked to the website (which was updated in May, though its contents could still use more work) because it has examples of how well the suite fits together.
I don't know much about JAX. I've seen competent benchmarks showing an order-of-magnitude benefit from using ReverseDiff from the JuliaDiff suite over Autograd, the Python library whose design inspired PyTorch's reverse-mode autodiff.
https://www.juliadiff.org/
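To give a sense of what using ReverseDiff looks like, here is a minimal sketch (assuming ReverseDiff.jl is installed; the function `f` is just an illustrative example, not from any benchmark):

```julia
using ReverseDiff

# A simple scalar-valued function of a vector: f(x) = sum(x.^2).
f(x) = sum(abs2, x)

x = rand(3)

# Reverse-mode gradient via ReverseDiff; analytically this is 2x.
g = ReverseDiff.gradient(f, x)
```

For repeated gradient calls on inputs of the same shape, ReverseDiff also lets you pre-record and compile a tape (`ReverseDiff.GradientTape` / `ReverseDiff.compile`), which is where much of its performance advantage comes from.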