How to access weights of each group in a grouped convolution?

A convolution layer with g groups by definition performs g convolution operations, each with a weight of size (out//g, in//g, k, k). How can I access those weights?

If you’ve created the conv layer as a module attribute (e.g. self.conv = nn.Conv2d(...)), you can access the weight tensor via:

model = MyModel()
kernels = model.conv.weight

in this case

kernels.size()

will return (out, in//g, k, k).
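For instance (a small sketch; the channel counts and group count here are arbitrary, chosen only to make the shape visible):

```python
import torch
from torch import nn

# Hypothetical example: 6 input channels, 12 output channels, 3 groups
conv = nn.Conv2d(6, 12, kernel_size=3, groups=3)

# The weight tensor holds all groups stacked along dim 0:
# shape is (out_channels, in_channels // groups, k, k)
print(tuple(conv.weight.shape))  # (12, 2, 3, 3)
```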

I would like to obtain g kernels, each of size (out//g, in//g, k, k), corresponding to the g side-by-side conv operations.

You could probably use torch.chunk to get the per-group kernels:

import torch
from torch import nn
from torch.nn import functional as F

groups = 6
conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
kernels = torch.stack(conv.weight.chunk(groups, 0))

# Reference
x = torch.randn(1, 6, 24, 24)
out = conv(x)

# Comparison
o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
print(torch.allclose(o1, out[:, :2]))
import torch
from torch import nn
from torch.nn import functional as F
groups = 6
conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
kernels = torch.stack(conv.weight.chunk(groups, 0))
# Reference
x = torch.randn(1, 6, 24, 24)
out = conv(x)
# Comparison
o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
print(torch.allclose(o1, out[:, :2]))
False

Should I expect this False output?

No, as it returns True on my machine.
I’ve run the code 1000 times in a loop and don’t get a False value.

Then I have no idea why…

Which PyTorch version are you using?

Thanks, I am currently using 1.1.0.

That’s strange.
I’ve tested the code again on 1.1.0 as well as 1.1.0.dev20190502 and both seem to run without an issue:


import torch
from torch import nn
from torch.nn import functional as F

for _ in range(1000):
    groups = 6
    conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
    kernels = torch.stack(conv.weight.chunk(groups, 0))
    # Reference
    x = torch.randn(1, 6, 24, 24)
    out = conv(x)
    # Comparison
    o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
    if not torch.allclose(o1, out[:, :2]):
        print('False')

Maybe we are unlucky and triggering this issue.
Could you change the number of input channels and run the test again?

import torch
from torch import nn
from torch.nn import functional as F
for _ in range(1000):
    groups = 12
    conv = nn.Conv2d(12, 12, 3, 1, 1, groups=groups, bias=False)
    kernels = torch.stack(conv.weight.chunk(groups, 0))
    # Reference
    x = torch.randn(1, 12, 24, 24)
    out = conv(x)
    # Comparison
    o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
    if not torch.allclose(o1, out[:, :2]):
        print('False')

Still getting False…

o1 in your example will have the shape [1, 1, 24, 24], while out[:, :2] will be [1, 2, 24, 24].
Could you change the condition to if not torch.allclose(o1, out[:, :1]) and run it again?