I am trying to implement a 1d diffusion model. I am not able to understand the Upsample method, when I am trying to use that in the class Unet()
it is not able to correctly perform the upsampling. The dimension seems to get mistaken. I know that it is something to do with the Upsample
method as I have checked it with printing the shapes of the inputs and the blocks, but I am not able to understand this.
Can someone help me this. The problem is most probably in nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1)
.
Model:
def Upsample(dim, dim_out = None):
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1)
)
def Downsample(dim, dim_out = None):
return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
# building block modules
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
# model
class Unet1D(nn.Module):
def __init__(
self,
dim,
inp_dim,
init_dim = None,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 1,
self_condition = False,
resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
learned_sinusoidal_dim = 16
):
super().__init__()
# determine dimensions
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding = 3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
]))
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
]))
def forward(self, x, time):
x = x.unsqueeze(1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, downsample in self.downs:
h.append(x)
x = block2(x, t)
h.append(x)
x = downsample(x)
print('downsample',downsample)
print(x.shape)
for block1, block2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = upsample(x)
print('upsample',upsample)
print(x.shape)
return x
for batch in train_loader:
batch = batch[0]
t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
x = unet(batch, t)
Error:
downsample Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 64, 354])
downsample Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 128, 177])
downsample Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 256, 88])
downsample Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
torch.Size([200, 512, 88])
upsample Sequential(
(0): Upsample(scale_factor=2.0, mode='nearest')
(1): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
)
torch.Size([200, 256, 176])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-33-daab6b8d333a> in <cell line: 1>()
14
15 #z , z_mu, z_var = unet(batch_noisy, t)
---> 16 x = unet(batch_noisy, t)
17
18 predicted_noise = z
1 frames
<ipython-input-22-6c96cb82eefe> in forward(self, x, time)
124 for block1, block2, upsample in self.ups:
125
--> 126 x = torch.cat((x, h.pop()), dim = 1)
127 x = block1(x, t)
128 x = torch.cat((x, h.pop()), dim = 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 176 but got size 177 for tensor number 1 in the list.
More information:
When I change the values of kernel size and padding; then lines then I get the following error: nn.Conv1d(dim, default(dim_out, dim), 4, padding = 2)
The first upsampling goes through.
def Upsample(dim, dim_out = None):
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv1d(dim, default(dim_out, dim), 4, padding = 2)
)
Error:
downsample Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 64, 354])
downsample Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 128, 177])
downsample Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 256, 88])
downsample Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
torch.Size([200, 512, 88])
upsample Sequential(
(0): Upsample(scale_factor=2.0, mode='nearest')
(1): Conv1d(512, 256, kernel_size=(4,), stride=(1,), padding=(2,))
)
torch.Size([200, 256, 177])
upsample Sequential(
(0): Upsample(scale_factor=2.0, mode='nearest')
(1): Conv1d(256, 128, kernel_size=(4,), stride=(1,), padding=(2,))
)
torch.Size([200, 128, 355])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-36-daab6b8d333a> in <cell line: 1>()
14
15 #z , z_mu, z_var = unet(batch_noisy, t)
---> 16 x = unet(batch_noisy, t)
17
18 predicted_noise = z
1 frames
<ipython-input-34-c0e422be0aa6> in forward(self, x, time)
123 for block1, block2, upsample in self.ups:
124
--> 125 x = torch.cat((x, h.pop()), dim = 1)
126 x = block1(x, t)
127 x = torch.cat((x, h.pop()), dim = 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 355 but got size 354 for tensor number 1 in the list.