RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)


am getting this error while train the model, could anybody tell me how can i resolve this issue? following is the code

class LinearNoiseScheduler:
r"“”
Class for the linear noise scheduler that is used in DDPM.
“”"
def init(self, num_timesteps, beta_start, beta_end):
self.num_timesteps = num_timesteps
self.beta_start = beta_start
self.beta_end = beta_end

    self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
    self.alphas = 1. - self.betas
    self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
    self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
    
def add_noise(self, original, noise, t):
    r"""
    Forward method for diffusion
    :param original: Image on which noise is to be applied
    :param noise: Random Noise Tensor (from normal dist)
    :param t: timestep of the forward process of shape -> (B,)
    :return:
    """
    original_shape = original.shape
    batch_size = original_shape[0]
    
    sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t].reshape(batch_size)
    sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t].reshape(batch_size)
    
    # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
    for _ in range(len(original_shape)-1):
        sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
    for _ in range(len(original_shape)-1):
        sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

    # Apply and Return Forward process equation
    return (sqrt_alpha_cum_prod.to(original.device) * original
            + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise)
    
def sample_prev_timestep(self, xt, noise_pred, t):
    r"""
        Use the noise prediction by model to get
        xt-1 using xt and the nosie predicted
    :param xt: current timestep sample
    :param noise_pred: model noise prediction
    :param t: current timestep we are at
    :return:
    """
    x0 = (xt - (self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred)) / torch.sqrt(self.alpha_cum_prod[t])
    x0 = torch.clamp(x0, -1., 1.)

    mean = xt - ((self.betas[t])*noise_pred)/(self.sqrt_one_minus_alpha_cum_prod[t])
    mean = mean / torch.sqrt(self.alphas[t])
    
    if t == 0:
        return mean, mean
    else:
        variance = (1-self.alpha_cum_prod[t-1]) / (1.0 - self.alpha_cum_prod[t])
        variance = variance * self.betas[t]
        sigma = variance ** 0.5
        z = torch.randn(xt.shape).to(xt.device)
        
        # OR
        # variance = self.betas[t]
        # sigma = variance ** 0.5
        # z = torch.randn(xt.shape).to(xt.device)
        return mean + sigma*z, x0

Based on the error message it seems t might be on the wrong device in:

sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t]

Could you check its .device attribute and move it to the same one self.sqrt_alpha_cum_prod uses?

1 Like