How can I make a model that calls `log_prob` and creates local tensors in `forward` trainable in parallel?

I created a toy model like the following and wanted to wrap it with DistributedDataParallel, but it fails with tensor device-mismatch errors (especially inside the distributions):

from torch import nn
from torch.distributions import Distribution, Uniform, Normal
from torch import Tensor, diag, zeros, randn

class Opt(nn.Module):
    """Toy operator that right-multiplies its input by two 3x3 matrices.

    Each matrix is rebuilt on every call from entries of ``p`` combined with
    the identity matrix ``I``. ``I`` is registered as a buffer and the
    per-call matrices are allocated with the buffer's dtype and device, so
    they follow the module through ``.to()`` / ``.cuda()`` instead of always
    being created on the default device.
    """

    def __init__(self):
        # Required before buffers can be registered and before the module
        # can be called.
        super().__init__()
        # Constant identity matrix; non-persistent so it travels with the
        # module across devices without adding an entry to the state_dict.
        self.register_buffer("I", diag(Tensor([1., 1., 1.])), persistent=False)

    def forward(self, x: Tensor, p: Tensor) -> Tensor:
        """Apply both transforms in sequence.

        Args:
            x: Tensor whose last dimension is 3 (it is right-multiplied by
                3x3 matrices).
            p: Tensor indexed as ``p[0:2]`` and ``p[2:4]``, i.e. it needs at
                least four entries along its first dimension.

        Returns:
            ``x`` after ``_some_transforms_1`` and then ``_some_transforms_2``.
        """
        x = self._some_transforms_1(x, p)
        x = self._some_transforms_2(x, p)
        return x

    def _blank_matrix(self) -> Tensor:
        """Return a 3x3 zero matrix matching the dtype and device of ``I``."""
        return zeros(3, 3, dtype=self.I.dtype, device=self.I.device)

    def _some_transforms_1(self, x: Tensor, p: Tensor) -> Tensor:
        """Right-multiply ``x`` by ``I`` with ``p[0:2]`` added into row 0."""
        trans_matrix = self._blank_matrix()
        trans_matrix[0, 0:2] = p[0:2]
        trans_matrix += self.I
        return x @ trans_matrix

    def _some_transforms_2(self, x: Tensor, p: Tensor) -> Tensor:
        """Right-multiply ``x`` by ``-I`` with ``p[2:4]`` added into row 1."""
        trans_matrix = self._blank_matrix()
        trans_matrix[1, 0:2] = p[2:4]
        trans_matrix -= self.I
        return x @ trans_matrix
class Model(nn.Module):
    """Combine a transform module with the log-densities of two distributions.

    NOTE(review): ``torch.distributions.Distribution`` objects are not
    ``nn.Module``s, so they are stored as plain attributes and their tensors
    are not moved by ``Model.to()`` / ``Model.cuda()``. Construct them from
    tensors that already live on the device where ``p`` will be.
    """

    def __init__(self, opt: nn.Module, dist_1: Distribution, dist_2: Distribution):
        """Store the transform and the two distributions.

        Args:
            opt: Module called as ``opt(x, p)``; registered as a submodule.
            dist_1: Distribution whose ``log_prob`` is evaluated at ``p``.
            dist_2: Distribution whose ``log_prob`` is evaluated at ``p``.
        """
        # Must run before assigning a submodule; otherwise ``self.opt = opt``
        # raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.opt = opt
        self.dist_1 = dist_1
        self.dist_2 = dist_2

    def forward(self, x: Tensor, p: Tensor, w: float):
        """Return ``opt(x, p).sum() * dist_1.log_prob(p) * dist_2.log_prob(p)``.

        Args:
            x: Input forwarded to ``opt``.
            p: Parameter tensor passed to ``opt`` and to both ``log_prob`` calls.
            w: Currently unused; kept so existing callers keep working.
        """
        x = self.opt(x, p)
        lp_1 = self.dist_1.log_prob(p)
        lp_2 = self.dist_2.log_prob(p)
        return x.sum() * lp_1 * lp_2
# One hundred random 3x3 inputs, visited one at a time in the loop below.
data = randn(100, 3, 3)

opt = Opt()
# Normal(0, 1) and Uniform(-1, 1), each over four components.
dist_1 = Normal(Tensor([0, 0, 0, 0]), Tensor([1, 1, 1, 1]))
dist_2 = Uniform(Tensor([-1, -1, -1, -1]), Tensor([1, 1, 1, 1]))
model = Model(opt, dist_1, dist_2)

for x in data:
    # Draw a batch of one sample from dist_2, then drop the size-1 axis.
    draw = dist_2.sample((1,))
    p = draw.squeeze()
    loss = model(x, p, 1.)

I tried a few changes, such as adding `self.I = nn.Parameter(self.I)` and adjusting the `trans_matrix = ...` assignment, to make the model trainable in parallel, but I still have a problem with the distributions:

(self.lower_bound <= value) & (value <= self.upper_bound)

The `lower_bound` and `upper_bound` in `torch.distributions.constraints._Interval` are not automatically moved to CUDA devices. Could anyone tell me how to make this model trainable in parallel?