Hello, I'm trying to implement txt2img Stable Diffusion generation, but I ran into a problem: when I try to use model.to() and model.eval(), an error is raised. Here is the code:
import torch
from samplers import sampler_list as sm
import torch.nn as nn
class generate_txt2img(nn.Module):
    def __init__(self, model, prompt, guidance_scale, width, height, steps, sampler='dpmsolver++', seed=None, negative_prompt=None):
        super(generate_txt2img, self).__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(device)
        model.eval()
        if negative_prompt is None:
            negative_prompt_hid = ""
            negative_prompt = negative_prompt_hid
        image = torch.randn(1, 3, height, width, device=device)
        if seed is not None:
            torch.manual_seed(seed)
        if sampler == 'dpmsolver++':
            sampler_var = sm.dpm_solver_plus(model=self.model)
        elif sampler == 'dpmsolver':
            sampler_var = sm.dpm_solver(model=self.model)
        with torch.no_grad():
            for _ in range(steps):
                z = torch.randn_like(image)
                positive_loss = self.model(image, z, prompt, guidance_scale)
                if negative_prompt:
                    negative_loss = self.model(image, z, negative_prompt, None, guidance_scale)
                    loss = positive_loss - negative_loss
                else:
                    loss = positive_loss
                self.model.zero_grad()
                loss.backward()
                image = sampler_var(sampler, image, image.grad)
        image.clamp_(0, 1)
        self.generated_image = image[0].cpu()
P.S.: The model is loaded in another script I wrote and passed in through a variable.
The error in your title means this is a plain Python error, not something specific to PyTorch - the object you call .to() on is None. You'll have to show us the code that calls generate_txt2img() for us to help.
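You can reproduce the same message in isolation if the object is None (a hypothetical reproduction, not your actual code):

model = None
model.to("cuda")  # AttributeError: 'NoneType' object has no attribute 'to'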
Also, generate_txt2img() doesn't need to be an nn.Module - you can use a plain function.
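A rough sketch of what that could look like (the sampler and model call signatures below are just copied from your snippet, so treat them as assumptions about your actual interfaces):

import torch
from samplers import sampler_list as sm

def generate_txt2img(model, prompt, guidance_scale, width, height, steps,
                     sampler='dpmsolver++', seed=None, negative_prompt=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    if seed is not None:
        torch.manual_seed(seed)
    image = torch.randn(1, 3, height, width, device=device)
    # pick the sampler exactly as in your __init__
    if sampler == 'dpmsolver++':
        sampler_var = sm.dpm_solver_plus(model=model)
    else:
        sampler_var = sm.dpm_solver(model=model)
    with torch.no_grad():
        for _ in range(steps):
            # ... the same sampling loop you already have ...
            pass
    return image[0].cpu()

The caller then gets the image back as a return value instead of reading self.generated_image.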
from download_components import download_sd_base as sd_download
from load_models import load_model as lm
import os
from other_components import model_manager
import time
import torch
from generation_components.txt2img import generate_txt2img
save_path = "base_models"
if not os.path.exists(save_path):
    os.makedirs(save_path)
#files_with_extensions = [file for file in os.listdir(save_path) if file.endswith((".ckpt", ".safetensors"))]
if not os.path.isfile(save_path + "/sd-v1-5-pruned.safetensors"):
    sd_download.download_model_sd(save_path=save_path, model='sd', model_ext="safetensors")
loaded_model = lm.load_model(model_path=save_path + '/sd-v1-5-pruned.safetensors')
time.sleep(1)
os.system('clear')
generate_txt2img(model=loaded_model,
                 prompt="A man, sitting on chair", negative_prompt="NSFW",
                 guidance_scale=10, width=512, height=512, steps=25, seed=2582485, sampler='dpmsolver++')
It seems that model in the previous snippet is loaded_model, which comes from lm.load_model(), which is not shown here. You'll have to work out why that function returns None - it could be due to a missing file at the supplied path.
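A quick way to confirm that before calling generate_txt2img() (just a sanity check, not a fix):

loaded_model = lm.load_model(model_path=save_path + '/sd-v1-5-pruned.safetensors')
print(type(loaded_model))  # if this prints <class 'NoneType'>, load_model() failed and fell through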
It's loading the weights when I start the script, and it looks like it really does load the model:
import torch
from safetensors.torch import load_file
def load_model(model_path: str) -> torch.nn.Module:
    if ".ckpt" in model_path:
        try:
            checkpoint = torch.load(model_path)
            model = torch.nn.Module()
            model.load_state_dict(checkpoint['state_dict'])
            print("Model has been successfully loaded.")
            return model
        except FileNotFoundError:
            print(f"Error while loading model: File '{model_path}' not found.")
        except Exception as e:
            print(f"Error while loading model: {str(e)}")
    elif ".safetensors" in model_path:
        try:
            model = torch.nn.Module()
            checkpoint = load_file(model_path)
            model.load_state_dict(checkpoint)
            print("Model has been successfully loaded.")
            return model
        except FileNotFoundError:
            print(f"Error while loading model: File '{model_path}' not found.")
        except Exception as e:
            print(f"Error while loading model: {str(e)}")
    else:
        raise ValueError('Your model is unsupported.\nUse model with ".ckpt" or ".safetensors" extensions.')