@jit.export methods disappear after jit.load(pth)

Code:

@torch.jit.interface
class placeholder:
    
    def forward(self,x : torch.Tensor, t : torch.Tensor, labels : torch.Tensor) -> torch.Tensor:
        pass
    
    def eval(self) -> None:
        pass
    
    def parameters(self) -> torch.Tensor:
        pass

class Diffuser(nn.Module):
    def __init__(self, T : int = 1000):
        super(Diffuser, self).__init__()
        self.T : int = T
        s = 0.008  # Small constant to prevent excessively small betas at t=0
        
        self.beta_schedule = 0.2 * (1 - torch.cos((torch.linspace(0, 3.14159, self.T) - s) / (1 - s)))
        self.beta_schedule = self.beta_schedule.clamp(min=1e-4, max=0.02)  # Prevent extreme values
        
        self.alpha_schedule = 1.0 - self.beta_schedule
        self.alpha_cumprod = torch.cumprod(self.alpha_schedule, dim=0)
        
    def forward(self, x):
        pass
        
    @torch.jit.export 
    def fwd_diffusion(self,imgs : torch.Tensor, t : torch.Tensor):
        
        """
        apply noise to every img in imgs based 9n the respective t value in the t tensor
            
        params:
            imgs: tensor like (batch_size,3,H,W) (from the dataloader)
            t : tensor like (batch_size,) (random from 0 to T)
            
        output:
            noise_x: tensor like(imgs.shape) = imgs += random noise based on t
            noise : tensor like(imgs.shape) = the noise that was added to the imgs
        """
        noise_x = torch.zeros_like(imgs,device = imgs.device)
            
        noise = torch.randn_like(imgs,device=imgs.device)
            
        for i in range(imgs.shape[0]):
            
            noise_x[i] = torch.sqrt(self.alpha_cumprod[t[i]]) * imgs[i] + torch.sqrt(1 - self.alpha_cumprod[t[i]]) * noise[i]

        return noise_x, noise
    
    @torch.jit.export     
    def DDPM(self,model : placeholder, device : torch.device,H : int = 32, W : int = 32, label : int = 6,random_variance : float = 1.0):
        """
            
        """
        
        if random_variance<0:
            return
        else:
            random_variance = torch.tensor(random_variance)
        
        model.eval()
        out = torch.randn((1,1,H,W)).to(device)
        

        beta = self.beta_schedule
        alpha = self.alpha_schedule
        alpha_c = self.alpha_cumprod
        
        label = torch.tensor([label]).to(device)
        
        with torch.no_grad():
            for t in range(self.T-1,0,-1):
                
                alpha_c_prev= alpha_c[t-1]
            
                c1 = 1.0/torch.sqrt(alpha[t])
                c2 = beta[t]/torch.sqrt(1-alpha_c[t])
                
                pred = model.forward(out,torch.tensor([t]).to(device),label)
                
                out = c1 * (out - c2 * pred)
                
                if t > 1: 
                    noise = torch.randn((1,1,H,W)).to(device) * torch.sqrt(beta[t])
                    variance = beta[t] * (1 - alpha_c_prev)/(1 - alpha_c[t])
                    sigma = torch.sqrt(variance)
                    out = out + noise * sigma * random_variance
        #mby clamp to -1,1
        return out

    @torch.jit.export
    def DDIM(self, model : placeholder,device : torch.device, H : int = 32, W : int = 32, batch_size : int = 1, label : int = 0, sampling_steps : int = 25, img_type : str = "bw", eta : float = 0.0):
        
        

        if img_type == "rgb":
            channels = 3
        else:
            channels = 1
        
        timesteps = torch.linspace(0,self.T,sampling_steps)
        
        
        model.eval()
        
        
        x_t = torch.randn((batch_size,channels,H,W)).to(device)
        out = torch.zeros_like(x_t).to(device)
        
        beta = self.beta_schedule
        alpha = self.alpha_schedule
        alpha_c = self.alpha_cumprod
        
        
        label = torch.full((batch_size,),label).to(device) 
        
        for i, t in enumerate(reversed(timesteps)):
            t_prev = timesteps[i-1] if i>0 else torch.tensor(0).to(out.device)
            
            t_tensor = torch.full((batch_size,),t).to(device)
            
            pred = model.forward(x_t,t_tensor,label)
            
            out = (x_t - torch.sqrt( 1 - alpha[t]) * pred) / torch.sqrt(alpha[t])
            sigma = eta * torch.sqrt((1 - alpha[t]) / ( 1 - alpha[t_prev])) * torch.sqrt(1 - alpha[t] / alpha[t_prev])
            x_t = torch.sqrt(alpha[t]) * out + torch.sqrt(1 - alpha[t_prev]) * pred
            
        return out
        
if __name__ == "__main__":
    import os 
    current_dir = os.path.dirname(os.path.abspath(__file__))
    
    diff = torch.jit.script(Diffuser(500))
    print(hasattr(diff, 'DDIM'))
    print(hasattr(diff, 'DDPM'))
    diff = torch.jit.save(diff,current_dir + '/Diffuser.pt')
    
    print(hasattr(diff, 'DDIM'))
    print(hasattr(diff, 'DDPM'))

This will give the following error

RuntimeError: outputs_[i]->uses().empty() INTERNAL ASSERT FAILED at “/home/runner/.termux-build/python-torch/src/torch/csrc/jit/ir/ir.cpp”:1307, please report a bug to PyTorch.

If Diffuser inherits from torch.jit.ScriptModel it runs but the saved model will not have the methods (e.g. DDIM)

I don’t know why this is as I’ve used the guide from the docs>jit area

If someone knows whats happening here I’d be happy to know