A convolution layer with g groups by definition performs g convolution operations, each with a weight of size (out//g, in//g, k, k). How can I access those weights?
If you’ve created the conv layer as a module attribute (e.g. self.conv = nn.Conv2d(...)), you can access the weight matrix via:
model = MyModel()
kernels = model.conv.weight
In this case, kernels.size() will return (out, in//g, k, k).
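As a quick sanity check of that shape, here is a minimal sketch. The sizes (6 input channels, 12 output channels, 3 groups) are assumptions picked only for illustration and are not taken from the question:

```python
import torch
from torch import nn

# Hypothetical sizes for illustration: in=6, out=12, k=3, g=3
conv = nn.Conv2d(6, 12, kernel_size=3, padding=1, groups=3, bias=False)

# The weight is stored as a single tensor of shape (out, in//g, k, k)
weight_shape = tuple(conv.weight.shape)
print(weight_shape)  # (12, 2, 3, 3)
```

The first dimension holds the output channels of all groups stacked one after another, which is why splitting along dim 0 should recover the per-group kernels.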
I would like to obtain g kernels, each of size (out//g, in//g, k, k), corresponding to the g side-by-side conv layers.
You could probably use torch.chunk to get the paired kernels:
groups = 6
conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
kernels = torch.stack(conv.weight.chunk(groups, 0))
# Reference
x = torch.randn(1, 6, 24, 24)
out = conv(x)
# Comparison
o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
print(torch.allclose(o1, out[:, :2]))
import torch
from torch import nn
from torch.nn import functional as F
groups = 6
conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
kernels = torch.stack(conv.weight.chunk(groups, 0))
# Reference
x = torch.randn(1, 6, 24, 24)
out = conv(x)
# Comparison
o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
print(torch.allclose(o1, out[:, :2]))
False
Should I expect this False output?
No, as it returns True on my machine.
I’ve run the code 1000 times in a loop and don’t get a False value.
Which PyTorch version are you using?
That’s strange.
I’ve tested the code again on 1.1.0 as well as 1.1.0.dev20190502 and both seem to run without an issue:
import torch
from torch import nn
from torch.nn import functional as F
for _ in range(1000):
    groups = 6
    conv = nn.Conv2d(6, 12, 3, 1, 1, groups=groups, bias=False)
    kernels = torch.stack(conv.weight.chunk(groups, 0))
    # Reference
    x = torch.randn(1, 6, 24, 24)
    out = conv(x)
    # Comparison
    o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
    if not torch.allclose(o1, out[:, :2]):
        print('False')
Maybe we are unlucky and triggering this issue.
Could you change the number of input channels and run the test again?
import torch
from torch import nn
from torch.nn import functional as F
for _ in range(1000):
    groups = 12
    conv = nn.Conv2d(12, 12, 3, 1, 1, groups=groups, bias=False)
    kernels = torch.stack(conv.weight.chunk(groups, 0))
    # Reference
    x = torch.randn(1, 12, 24, 24)
    out = conv(x)
    # Comparison
    o1 = F.conv2d(x[:, :1], kernels[0], padding=1)
    if not torch.allclose(o1, out[:, :2]):
        print('False')
still…
With groups=12 and 12 output channels, each group produces only a single output channel, so o1 in your example will have the shape [1, 1, 24, 24], while out[:, :2] will be [1, 2, 24, 24].
Could you change the condition to if not torch.allclose(o1, out[:, :1]) and run it again?
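To avoid hard-coding the slice widths, the comparison could be written so that it derives them from the layer configuration. The following is only a sketch: the helper name check_groups, the fixed seed, and the atol value are my own assumptions and not part of the code above. It also compares shapes before calling torch.allclose, since torch.allclose appears to broadcast its inputs and would otherwise silently compare a one-channel result against two channels.

```python
import torch
from torch import nn
from torch.nn import functional as F


def check_groups(in_channels, out_channels, groups):
    """Compare each group's standalone conv against the grouped conv output.

    Hypothetical helper for illustration; returns one bool per group.
    """
    torch.manual_seed(0)
    conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=groups, bias=False)
    # One kernel per group, each of shape (out//g, in//g, k, k)
    kernels = conv.weight.chunk(groups, dim=0)
    in_per_group = in_channels // groups
    out_per_group = out_channels // groups

    x = torch.randn(1, in_channels, 24, 24)
    out = conv(x)

    results = []
    for g, kernel in enumerate(kernels):
        # Input and output channel slices that belong to group g
        x_g = x[:, g * in_per_group:(g + 1) * in_per_group]
        ref_g = out[:, g * out_per_group:(g + 1) * out_per_group]
        o_g = F.conv2d(x_g, kernel, padding=1)
        # Require equal shapes first so broadcasting cannot hide a mismatch
        same = o_g.shape == ref_g.shape and torch.allclose(o_g, ref_g, atol=1e-5)
        results.append(bool(same))
    return results


print(all(check_groups(6, 12, 6)))    # first configuration from the thread
print(all(check_groups(12, 12, 12)))  # second configuration from the thread
```

With the slices computed from in//g and out//g, the same check should work for both configurations without editing the indices by hand.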