This code was very helpful because I want to do same thing by using conv3d.
However when I actually ran a below program, I wonder this code may be not work in case of using conv3d.
import torch
import torch.nn as nn
import torch.nn.functional as F
#Setup
N, T, C, H, W = 10, 24, 3, 24, 24
x = torch.stack([torch.full(size=(C,T,H,W), fill_value=float(n)) for n in range(N)])
for x_elem in x:
print(x_elem.abs().max())
# Create filterset for each sample
OC = 15
weights = []
for _ in range(N):
weight = nn.Parameter(torch.full(size=(OC,C,5,5,5), fill_value=float(_)))
weights.append(weight)
print(weight.abs().max())
# Apply manually
outputs = []
for idx in range(N):
input = x[idx:idx+1]
weight = weights[idx]
output = F.conv3d(input, weight, stride=1, padding=2)
outputs.append(output)
outputs = torch.stack(outputs)
outputs = outputs.squeeze(1) # remove fake batch dimension
# outputs = torch.cat(outputs, dim=0)
print(outputs.shape)
for output in outputs:
print(output.abs().max())
# Use grouped approach
weights = torch.stack(weights)
weights = weights.view(-1, C, 5, 5, 5)
print(weights.shape)
# move batch dim into channels
x = x.view(1, -1, T, H, W)
print(x.shape)
# Apply grouped conv
outputs_grouped = F.conv3d(x, weights, stride=1, padding=2, groups=N)
outputs_grouped = outputs_grouped.view(10, 15, 24, 24, 24)
print(outputs_grouped.shape)
# Compare
print((outputs - outputs_grouped).abs().max())
for output_grouped in outputs_grouped:
print(outputs_grouped.abs().max())
outputs
tensor(0.)
tensor(1.)
tensor(2.)
tensor(3.)
tensor(4.)
tensor(5.)
tensor(6.)
tensor(7.)
tensor(8.)
tensor(9.)
tensor(0., grad_fn=<MaxBackward1>)
tensor(1., grad_fn=<MaxBackward1>)
tensor(2., grad_fn=<MaxBackward1>)
tensor(3., grad_fn=<MaxBackward1>)
tensor(4., grad_fn=<MaxBackward1>)
tensor(5., grad_fn=<MaxBackward1>)
tensor(6., grad_fn=<MaxBackward1>)
tensor(7., grad_fn=<MaxBackward1>)
tensor(8., grad_fn=<MaxBackward1>)
tensor(9., grad_fn=<MaxBackward1>)
torch.Size([10, 15, 24, 24, 24])
for output in outputs:
print(output.abs().max())
>tensor(0., grad_fn=<MaxBackward1>)
>tensor(375., grad_fn=<MaxBackward1>)
>tensor(1500., grad_fn=<MaxBackward1>)
>tensor(3375., grad_fn=<MaxBackward1>)
>tensor(6000., grad_fn=<MaxBackward1>)
>tensor(9375., grad_fn=<MaxBackward1>)
>tensor(13500., grad_fn=<MaxBackward1>)
>tensor(18375., grad_fn=<MaxBackward1>)
>tensor(24000., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
print(weights.shape)
>torch.Size([150, 3, 5, 5, 5])
print(x.shape)
>torch.Size([1, 30, 24, 24, 24])
print(outputs_grouped.shape)
>torch.Size([10, 15, 24, 24, 24])
print((outputs - outputs_grouped).abs().max())
>tensor(0., grad_fn=<MaxBackward1>)
for output_grouped in outputs_grouped:
print(outputs_grouped.abs().max())
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
>tensor(30375., grad_fn=<MaxBackward1>)
Displayed max values of outputs
in this code gradually increase with increasing torch.full
value.
On the other hand, all outputs_grouped
's max values are same.
I think it might be caused by outputs_grouped = outputs_grouped.view(10, 15, 24, 24, 24)
. But I couldnât this solution.
I would be grateful if you could tell me about your experience.