Error with upsampling layers

I am trying to implement a 1D diffusion model. I cannot understand the Upsample method: when I use it in the Unet() class, the upsampling does not come out correctly and the dimensions end up wrong. I know it has something to do with the Upsample method, because I have checked by printing the shapes of the inputs and the blocks, but I cannot work out why.

Can someone help me with this? The problem is most probably in nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1).

Model:

def Upsample(dim, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1)
    )

def Downsample(dim, dim_out = None):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
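
For reference, a Conv1d produces output length floor((L + 2*padding - kernel)/stride) + 1, so the stride-2 Downsample floors odd lengths down, and the 2x Upsample cannot recover the lost sample. A standalone sketch (hypothetical lengths) of what these two helpers do to the length dimension:

import torch
import torch.nn as nn

down = nn.Conv1d(1, 1, 4, 2, 1)                      # the Downsample conv above
up = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                   nn.Conv1d(1, 1, 3, padding=1))    # the Upsample stack above

x = torch.randn(1, 1, 177)        # odd length
y = down(x)
print(y.shape)                    # torch.Size([1, 1, 88]): floor((177 + 2 - 4)/2) + 1
print(up(y).shape)                # torch.Size([1, 1, 176]) != 177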

# building block modules

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
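
For context, scale_shift is how the time embedding conditions each Block, FiLM-style. A standalone sketch of the same math, with a plain nn.Conv1d standing in for WeightStandardizedConv2d and hypothetical dims:

import torch
import torch.nn as nn

x = torch.randn(2, 8, 32)
x = nn.Conv1d(8, 8, 3, padding=1)(x)
x = nn.GroupNorm(8, 8)(x)
scale = torch.randn(2, 8, 1)     # produced from the time embedding in ResnetBlock
shift = torch.randn(2, 8, 1)
x = x * (scale + 1) + shift      # broadcast over the length dimension
x = nn.SiLU()(x)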

# model
class Unet1D(nn.Module):
    def __init__(
        self,
        dim,
        inp_dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 1,
        self_condition = False,
        resnet_block_groups = 8,
        learned_variance = False,
        learned_sinusoidal_cond = False,
        random_fourier_features = False,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)
        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding = 3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        time_dim = dim * 4
        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)


        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
            ]))

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
            ]))


    def forward(self, x, time):

        x = x.unsqueeze(1)
        x = self.init_conv(x)
        r = x.clone()
        t = self.time_mlp(time)
        h = []

        for block1, block2, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            h.append(x)
            x = downsample(x)

            print('downsample',downsample)
            print(x.shape)

        for block1, block2, upsample in self.ups:

            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = upsample(x)

            print('upsample',upsample)
            print(x.shape)

        return x

for batch in train_loader:
    batch = batch[0]
    t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)
    x = unet(batch, t)

Error:

downsample Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 64, 354])
downsample Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 128, 177])
downsample Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 256, 88])
downsample Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
torch.Size([200, 512, 88])
upsample Sequential(
  (0): Upsample(scale_factor=2.0, mode='nearest')
  (1): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
)
torch.Size([200, 256, 176])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-33-daab6b8d333a> in <cell line: 1>()
     14 
     15         #z , z_mu, z_var = unet(batch_noisy, t)
---> 16         x = unet(batch_noisy, t)
     17 
     18         predicted_noise = z

1 frames
<ipython-input-22-6c96cb82eefe> in forward(self, x, time)
    124         for block1, block2, upsample in self.ups:
    125 
--> 126             x = torch.cat((x, h.pop()), dim = 1)
    127             x = block1(x, t)
    128             x = torch.cat((x, h.pop()), dim = 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 176 but got size 177 for tensor number 1 in the list.

More information:

When I change the kernel size and padding in those lines to nn.Conv1d(dim, default(dim_out, dim), 4, padding = 2), the first upsampling goes through, but then I get the following error:

def Upsample(dim, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, default(dim_out, dim), 4, padding = 2)
    )

Error:

downsample Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 64, 354])
downsample Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 128, 177])
downsample Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 256, 88])
downsample Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
torch.Size([200, 512, 88])
upsample Sequential(
  (0): Upsample(scale_factor=2.0, mode='nearest')
  (1): Conv1d(512, 256, kernel_size=(4,), stride=(1,), padding=(2,))
)
torch.Size([200, 256, 177])
upsample Sequential(
  (0): Upsample(scale_factor=2.0, mode='nearest')
  (1): Conv1d(256, 128, kernel_size=(4,), stride=(1,), padding=(2,))
)
torch.Size([200, 128, 355])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-daab6b8d333a> in <cell line: 1>()
     14 
     15         #z , z_mu, z_var = unet(batch_noisy, t)
---> 16         x = unet(batch_noisy, t)
     17 
     18         predicted_noise = z

1 frames
<ipython-input-34-c0e422be0aa6> in forward(self, x, time)
    123         for block1, block2, upsample in self.ups:
    124 
--> 125             x = torch.cat((x, h.pop()), dim = 1)
    126             x = block1(x, t)
    127             x = torch.cat((x, h.pop()), dim = 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 355 but got size 354 for tensor number 1 in the list.

If the problem is in a concat step and you are using an otherwise unmodified model, could it be that you are using an input shape that the model was not designed for?


If I’m not mistaken, this is a modification of lucidrains’ UNet model.

The issue in your case appears to be the size of your inputs (length 354 after the first downsample). The input length must be divisible by 2 raised to the number of halving downsamples: if there are 3 downsamples that cut the sequence length in half, then your input size should be divisible by 2^3 = 8.

So what’s happening is that when you start upsampling, the sequence lengths of your skip connections and your upsampled tensors drift out of step and can no longer be concatenated.
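
You can check that with the length formula (a quick sketch, using the kernel/stride/padding values from your printout):

# Conv1d output length: floor((L + 2*p - k) / s) + 1
def conv_out(L, k, s, p):
    return (L + 2 * p - k) // s + 1

L = 708
for _ in range(3):             # the three stride-2 downsamples
    L = conv_out(L, 4, 2, 1)   # 708 -> 354 -> 177 -> 88 (177 is odd, one sample is floored away)
print(L * 2)                   # 176 after the first 2x upsample, but the skip has length 177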

Please adjust your input size or remove the skip connections.


Hey, you are actually correct, I am using that model. My input size is fixed at 708; could you kindly tell me what I can do to adjust it?

Also, can you elaborate a bit on the skip connections? I understand skip connections from U-Net, but I’m not really sure how to remove them here. Thanks a lot!

Yeah, I think so. Could you kindly tell me what I can do now? My inputs are fixed at 708; shall I apply a linear layer before feeding the inputs to my model, to adjust them to a suitable shape?

It depends on what the L dimension of your data corresponds to. If it’s a spatial dimension, then you might look at something like AdaptiveAvgPool1d (see the PyTorch 2.0 documentation).

If it’s a “sequence”, or something where it doesn’t make sense to blur together or average different positions before the start of the model, then you might try a linear layer like you said.
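
A minimal sketch of both options (hypothetical target length of 704; which one fits depends on your data, as noted):

import torch
import torch.nn as nn

x = torch.randn(200, 1, 708)

pool = nn.AdaptiveAvgPool1d(704)   # spatial data: averages neighboring positions
print(pool(x).shape)               # torch.Size([200, 1, 704])

proj = nn.Linear(708, 704)         # sequence data: learned projection over the last dim
print(proj(x).shape)               # torch.Size([200, 1, 704])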

The skip connections are where you have the concat statements. If you remove them, you will then also need to halve the input channels of the upsample-path layers.

However, removing skip connections is the last thing I’d do, if I were you. You could just cut 4 samples off one end of your sequence (i.e. if it’s time-sequential, drop the 4 oldest time steps). That will make it 704, which is divisible by 16. Or pad both sides with 6 to get 720.
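
In code, the crop-or-pad option could look like this sketch:

import torch
import torch.nn.functional as F

x = torch.randn(200, 1, 708)
x_crop = x[..., 4:]                          # (200, 1, 704): drop the 4 oldest steps
x_pad = F.pad(x, (6, 6), mode='replicate')   # (200, 1, 720): both lengths divisible by 16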

You could try something like this:

def Upsample(dim, dim_out = None, padding = 1):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, default(dim_out, dim), 3, padding = padding)
    )

def Downsample(dim, dim_out = None, padding=1):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, padding)

...

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            pad = 1 if ind!=0 else 2 #<--- new line
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Downsample(dim_in, dim_out, padding=pad) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
            ]))

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            pad = 1 if ind!=3 else 0  #<--- new line
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Upsample(dim_out, dim_in, padding=pad) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
            ]))

Note: I haven’t tested it, and you may still need to tweak the upsample side, but, ideally, you should get 708 input -> 356 -> 178 -> 89 -> 178 -> 354 -> 708.
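
One quick (equally untested) way to check the numbers before wiring them in is to run the length formula over each stage:

def conv_len(L, k, s, p):
    return (L + 2 * p - k) // s + 1

L = 708
for k, s, p in [(4, 2, 2), (4, 2, 1), (4, 2, 1), (3, 1, 1)]:  # (kernel, stride, padding) per down stage
    L = conv_len(L, k, s, p)
    print(L)   # verify each length before adding the mirror-image up path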


I am getting the following error:

downsample Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 64, 354])
downsample Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 128, 177])
downsample Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
torch.Size([200, 256, 88])
downsample Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
torch.Size([200, 512, 88])
upsample Sequential(
  (0): Upsample(scale_factor=2.0, mode='nearest')
  (1): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
)
torch.Size([200, 256, 176])


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-48-daab6b8d333a> in <cell line: 1>()
     14 
     15         #z , z_mu, z_var = unet(batch_noisy, t)
---> 16         x = unet(batch_noisy, t)
     17 
     18         predicted_noise = z

1 frames
<ipython-input-43-ce10267ac3d5> in forward(self, x, time)
    125         for block1, block2, upsample in self.ups:
    126 
--> 127             x = torch.cat((x, h.pop()), dim = 1)
    128             x = block1(x, t)
    129             x = torch.cat((x, h.pop()), dim = 1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 176 but got size 177 for tensor number 1 in the list.

Also, John, if possible could you kindly explain the layer architecture of lucidrains’ UNet model? I am basing my model on that one, so something like all the layers laid out sequentially would help. I see a model description when I load the model, but I can’t understand much from it.

Unet1D(
  (init_conv): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=(3,))
  (time_mlp): Sequential(
    (0): SinusoidalPosEmb()
    (1): Linear(in_features=64, out_features=256, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (downs): ModuleList(
    (0): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=128, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Conv1d(64, 64, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (1): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=128, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Conv1d(64, 128, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (2): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=256, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Conv1d(128, 256, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (3): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=512, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 256, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 256, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Conv1d(256, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    )
  )
  (ups): ModuleList(
    (0): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=1024, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(768, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv1d(768, 512, kernel_size=(1,), stride=(1,))
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(512, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      )
    )
    (1): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=512, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(384, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 256, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 256, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv1d(384, 256, kernel_size=(1,), stride=(1,))
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(256, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv1d(256, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      )
    )
    (2): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=256, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(192, 128, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv1d(192, 128, kernel_size=(1,), stride=(1,))
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(128, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Sequential(
        (0): Upsample(scale_factor=2.0, mode='nearest')
        (1): Conv1d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      )
    )
    (3): ModuleList(
      (0-1): 2 x ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=128, bias=True)
        )
        (block1): Block(
          (proj): WeightStandardizedConv2d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
          (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
      )
      (2): Residual(
        (fn): PreNorm(
          (fn): LinearAttention(
            (to_qkv): Conv1d(64, 384, kernel_size=(1,), stride=(1,), bias=False)
            (to_out): Sequential(
              (0): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
              (1): LayerNorm()
            )
          )
          (norm): LayerNorm()
        )
      )
      (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    )
  )
  (mid_block1): ResnetBlock(
    (mlp): Sequential(
      (0): SiLU()
      (1): Linear(in_features=256, out_features=1024, bias=True)
    )
    (block1): Block(
      (proj): WeightStandardizedConv2d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (block2): Block(
      (proj): WeightStandardizedConv2d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (res_conv): Identity()
  )
  (mid_attn): Residual(
    (fn): PreNorm(
      (fn): Attention(
        (to_qkv): Conv1d(512, 384, kernel_size=(1,), stride=(1,), bias=False)
        (to_out): Conv1d(128, 512, kernel_size=(1,), stride=(1,))
      )
      (norm): LayerNorm()
    )
  )
  (mid_block2): ResnetBlock(
    (mlp): Sequential(
      (0): SiLU()
      (1): Linear(in_features=256, out_features=1024, bias=True)
    )
    (block1): Block(
      (proj): WeightStandardizedConv2d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (block2): Block(
      (proj): WeightStandardizedConv2d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 512, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (res_conv): Identity()
  )
  (final_res_block): ResnetBlock(
    (mlp): Sequential(
      (0): SiLU()
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (block1): Block(
      (proj): WeightStandardizedConv2d(128, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (block2): Block(
      (proj): WeightStandardizedConv2d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (act): SiLU()
    )
    (res_conv): Conv1d(128, 64, kernel_size=(1,), stride=(1,))
  )
  (final_conv): Conv1d(64, 1, kernel_size=(1,), stride=(1,))
  (mu): Linear(in_features=704, out_features=704, bias=True)
  (var): Linear(in_features=704, out_features=704, bias=True)
)

I see. Try this instead:

def Upsample(dim, dim_out = None, kernel=3, padding = 1):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, default(dim_out, dim), kernel, padding = padding)
    )

def Downsample(dim, dim_out = None, padding=1):
    return nn.Conv1d(dim, default(dim_out, dim), 4, 2, padding)

...

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            pad = 3 if ind==0 else 1 
            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Downsample(dim_in, dim_out, padding=pad) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
            ]))

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            pad = 0 if ind==2 else 1  
            kernel = 5 if ind==2 else 3
            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Upsample(dim_out, dim_in, kernel=kernel, padding=pad) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = pad)
            ]))
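
If it helps, here is the length arithmetic this version should produce (the kernel=5, padding=0 upsample at ind==2 trims the overshoot back down), as a quick untested check:

def conv_len(L, k, s, p):
    return (L + 2 * p - k) // s + 1

L = 708
for k, s, p in [(4, 2, 3), (4, 2, 1), (4, 2, 1), (3, 1, 1)]:   # down: 356, 178, 89, 89
    L = conv_len(L, k, s, p)
for k, p, doubles in [(3, 1, True), (3, 1, True), (5, 0, True), (3, 1, False)]:
    if doubles:
        L = 2 * L                # nn.Upsample(scale_factor=2)
    L = conv_len(L, k, 1, p)     # up: 178, 356, 708, 708
print(L)                         # 708, matching the stored skip connections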

In regard to the architecture, skip connections mean that the outputs of the convolutional stages on the way down are stored, and those are then concatenated with the activations of matching size on the way back up. That happens at each down/up stage.

Here is a picture of what that looks like in the case of 2D U-Nets:

[U-Net architecture diagram showing the contracting and expanding paths]

The big gray arrows in the center show each skip connection.

What it does is give the model the original context during upsampling, plus the ability to “decide” whether the additional compute layers contributed anything of value or whether a minimally processed output is preferable.
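
Concretely, in this code a skip connection is just the channel-wise concatenation (hypothetical shapes):

import torch

x = torch.randn(200, 256, 88)      # decoder activation
skip = torch.randn(200, 256, 88)   # stored encoder activation from h
x = torch.cat((x, skip), dim=1)    # (200, 512, 88); lengths must match exactly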


Thanks a lot for this note.