Initialization trick in Masked Autoencoder (MAE)

Hi, I’m currently working with MAE and got curious about the initialization trick used when initializing self.patch_embed.

MAE uses timm.models.vision_transformer’s PatchEmbed, and PatchEmbed uses an nn.Conv2d to patchify the input.

The PatchEmbed conv is then initialized like this:

# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

As the comment says, the conv weights are intentionally flattened before initialization.
So why do they reshape the conv weights into the shape of an nn.Linear weight?
Is there any advantage to doing so?

Thanks in advance.

Flattening the conv weight makes a difference in the initialization, and I guess the authors saw an advantage in their training from it. Did you check the paper to see if they give an explanation there?

Here is a small example showing the difference:

import math
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, 3)

# standard conv weight shape
fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(conv.weight)
print(fan_in, fan_out)
# 27 144

gain = 1.0
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
# 0.10814761408717502

a = math.sqrt(3.0) * std 
# 0.18731716231633877

# flatten
fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(conv.weight.view(conv.weight.size(0), -1))
print(fan_in, fan_out)
# 27 16

gain = 1.0
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
# 0.21566554640687682

a = math.sqrt(3.0) * std 
# 0.37354368381881414

Note that a will be used in tensor.uniform_(-a, a) to initialize the tensor.
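To make the effect concrete, here is a small sanity check (my own sketch, not from the MAE repo): initializing the conv weight through its flattened view, as MAE does, draws values from the wider "linear" bound rather than the narrower conv bound computed above.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 16, 3)

# initialize through the flattened (16, 27) view, as in the MAE snippet
w = conv.weight.data
torch.nn.init.xavier_uniform_(w.view(w.shape[0], -1))

# bounds from the fan calculations above
a_conv = math.sqrt(3.0) * math.sqrt(2.0 / (27 + 144))   # ~0.1873
a_linear = math.sqrt(3.0) * math.sqrt(2.0 / (27 + 16))  # ~0.3735

# the largest weight magnitude exceeds the conv bound but stays
# within the linear bound, confirming the "linear-style" init
print(w.abs().max().item())
```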

@ptrblck Hi, thanks for the reply.
I checked the paper and didn’t find any explanation of it.

So my eventual hypothesis is that the purpose is to mimic the patchify behavior of the original ViT.
Even though the original ViT also uses a conv layer for patchify, what it conceptually does is reshape the image into B x N x (P^2 * C) and apply a linear transformation along dim=-1.

So I think MAE flattens the conv weight to initialize the conv layer as if it were that linear layer.
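To support this, here is a small sketch (toy sizes, bias disabled for simplicity; my own example) showing that a stride-P conv is mathematically the same as reshaping the image into B x N x (P^2 * C) patches and applying the flattened conv weight as a linear transform:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

P, C, D = 4, 3, 8                 # patch size, channels, embed dim (toy values)
img = torch.randn(2, C, 16, 16)   # B x C x H x W

# conv-based patchify, as in timm's PatchEmbed
conv = nn.Conv2d(C, D, kernel_size=P, stride=P, bias=False)
conv_out = conv(img).flatten(2).transpose(1, 2)      # B x N x D

# equivalent "linear" view: extract B x N x (C*P*P) patches,
# then multiply by the conv weight flattened to D x (C*P*P)
patches = img.unfold(2, P, P).unfold(3, P, P)        # B x C x 4 x 4 x P x P
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(2, 16, C * P * P)
linear_out = patches @ conv.weight.view(D, -1).t()   # B x N x D

print(torch.allclose(conv_out, linear_out, atol=1e-5))  # True
```

Since the conv is effectively a linear layer over flattened patches, initializing its weight in the flattened shape gives it exactly the xavier statistics that the equivalent nn.Linear would get.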

Do you think this makes sense?

Yes, I think your explanation sounds reasonable, but let’s also wait for an answer from the authors in the corresponding issue.
