Code:
@torch.jit.interface
class placeholder:
def forward(self,x : torch.Tensor, t : torch.Tensor, labels : torch.Tensor) -> torch.Tensor:
pass
def eval(self) -> None:
pass
def parameters(self) -> torch.Tensor:
pass
class Diffuser(nn.Module):
def __init__(self, T : int = 1000):
super(Diffuser, self).__init__()
self.T : int = T
s = 0.008 # Small constant to prevent excessively small betas at t=0
self.beta_schedule = 0.2 * (1 - torch.cos((torch.linspace(0, 3.14159, self.T) - s) / (1 - s)))
self.beta_schedule = self.beta_schedule.clamp(min=1e-4, max=0.02) # Prevent extreme values
self.alpha_schedule = 1.0 - self.beta_schedule
self.alpha_cumprod = torch.cumprod(self.alpha_schedule, dim=0)
def forward(self, x):
pass
@torch.jit.export
def fwd_diffusion(self,imgs : torch.Tensor, t : torch.Tensor):
"""
apply noise to every img in imgs based 9n the respective t value in the t tensor
params:
imgs: tensor like (batch_size,3,H,W) (from the dataloader)
t : tensor like (batch_size,) (random from 0 to T)
output:
noise_x: tensor like(imgs.shape) = imgs += random noise based on t
noise : tensor like(imgs.shape) = the noise that was added to the imgs
"""
noise_x = torch.zeros_like(imgs,device = imgs.device)
noise = torch.randn_like(imgs,device=imgs.device)
for i in range(imgs.shape[0]):
noise_x[i] = torch.sqrt(self.alpha_cumprod[t[i]]) * imgs[i] + torch.sqrt(1 - self.alpha_cumprod[t[i]]) * noise[i]
return noise_x, noise
@torch.jit.export
def DDPM(self,model : placeholder, device : torch.device,H : int = 32, W : int = 32, label : int = 6,random_variance : float = 1.0):
"""
"""
if random_variance<0:
return
else:
random_variance = torch.tensor(random_variance)
model.eval()
out = torch.randn((1,1,H,W)).to(device)
beta = self.beta_schedule
alpha = self.alpha_schedule
alpha_c = self.alpha_cumprod
label = torch.tensor([label]).to(device)
with torch.no_grad():
for t in range(self.T-1,0,-1):
alpha_c_prev= alpha_c[t-1]
c1 = 1.0/torch.sqrt(alpha[t])
c2 = beta[t]/torch.sqrt(1-alpha_c[t])
pred = model.forward(out,torch.tensor([t]).to(device),label)
out = c1 * (out - c2 * pred)
if t > 1:
noise = torch.randn((1,1,H,W)).to(device) * torch.sqrt(beta[t])
variance = beta[t] * (1 - alpha_c_prev)/(1 - alpha_c[t])
sigma = torch.sqrt(variance)
out = out + noise * sigma * random_variance
#mby clamp to -1,1
return out
@torch.jit.export
def DDIM(self, model : placeholder,device : torch.device, H : int = 32, W : int = 32, batch_size : int = 1, label : int = 0, sampling_steps : int = 25, img_type : str = "bw", eta : float = 0.0):
if img_type == "rgb":
channels = 3
else:
channels = 1
timesteps = torch.linspace(0,self.T,sampling_steps)
model.eval()
x_t = torch.randn((batch_size,channels,H,W)).to(device)
out = torch.zeros_like(x_t).to(device)
beta = self.beta_schedule
alpha = self.alpha_schedule
alpha_c = self.alpha_cumprod
label = torch.full((batch_size,),label).to(device)
for i, t in enumerate(reversed(timesteps)):
t_prev = timesteps[i-1] if i>0 else torch.tensor(0).to(out.device)
t_tensor = torch.full((batch_size,),t).to(device)
pred = model.forward(x_t,t_tensor,label)
out = (x_t - torch.sqrt( 1 - alpha[t]) * pred) / torch.sqrt(alpha[t])
sigma = eta * torch.sqrt((1 - alpha[t]) / ( 1 - alpha[t_prev])) * torch.sqrt(1 - alpha[t] / alpha[t_prev])
x_t = torch.sqrt(alpha[t]) * out + torch.sqrt(1 - alpha[t_prev]) * pred
return out
if __name__ == "__main__":
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
diff = torch.jit.script(Diffuser(500))
print(hasattr(diff, 'DDIM'))
print(hasattr(diff, 'DDPM'))
diff = torch.jit.save(diff,current_dir + '/Diffuser.pt')
print(hasattr(diff, 'DDIM'))
print(hasattr(diff, 'DDPM'))
This will give the following error
RuntimeError: outputs_[i]->uses().empty() INTERNAL ASSERT FAILED at “/home/runner/.termux-build/python-torch/src/torch/csrc/jit/ir/ir.cpp”:1307, please report a bug to PyTorch.
If Diffuser inherits from torch.jit.ScriptModel it runs but the saved model will not have the methods (e.g. DDIM)
I don’t know why this is as I’ve used the guide from the docs>jit area
If someone knows whats happening here I’d be happy to know