Currently, users can run many independent chains of `sgmcmc` by applying `torch.vmap` themselves. However, `torch.vmap` doesn't compose with a lot of existing PyTorch code. It might be nice to add a boolean flag such as `batched`, `batched_param`, or `parallel` that lets the user define a `log_posterior` that takes an extra batch dimension on its input and returns one on its output (one slice per chain). This way, parallel `sgmcmc` could be run without needing `torch.vmap`.
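A rough sketch of the idea in plain PyTorch, outside of any library API. The names `batched_log_posterior` and `sgld_step_batched`, the toy Gaussian model, and the choice of an SGLD-style update (`theta + lr/2 * grad + sqrt(lr) * noise`) are all hypothetical and only for illustration. Because the chains are independent, differentiating the *sum* of the per-chain log posteriors yields every chain's own gradient in a single backward pass, so no `torch.vmap` is required.

```python
import torch


def batched_log_posterior(params: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched log posterior (toy model, illustration only).

    params: (num_chains, dim), one parameter vector per chain.
    batch:  (n, dim), data shared by all chains.
    Returns (num_chains,), one log posterior value per chain.
    Model: N(0, I) prior, unit-variance Gaussian likelihood with mean = params
    (no N/n minibatch rescaling, to keep the sketch short).
    """
    log_prior = -0.5 * (params ** 2).sum(-1)
    # Broadcast (num_chains, 1, dim) - (1, n, dim) -> (num_chains, n, dim).
    resid = params[:, None, :] - batch[None, :, :]
    log_lik = -0.5 * (resid ** 2).sum(dim=(-2, -1))
    return log_prior + log_lik


def sgld_step_batched(params, log_posterior, batch, lr, generator=None):
    """Hypothetical SGLD-style step applied to every chain at once, without vmap."""
    params = params.detach().requires_grad_(True)
    log_post = log_posterior(params, batch)  # shape (num_chains,)
    # Each chain's value depends only on its own row of params, so the
    # gradient of the sum w.r.t. row k equals chain k's own gradient.
    (grad,) = torch.autograd.grad(log_post.sum(), params)
    noise = torch.randn(params.shape, generator=generator, dtype=params.dtype)
    return (params + 0.5 * lr * grad + lr ** 0.5 * noise).detach()
```

Usage with 8 chains of 2-dimensional parameters:

```python
gen = torch.Generator().manual_seed(0)
data = torch.randn(100, 2, generator=gen)
params = torch.zeros(8, 2)  # one row per chain
for _ in range(1000):
    params = sgld_step_batched(params, batched_log_posterior, data, lr=1e-3, generator=gen)
print(params.shape)  # torch.Size([8, 2])
```

Passing a random subset of `data` as `batch` on each step would make the gradients stochastic; the batching over chains works the same way either way.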