2D DWT convolution to 3D DWT convolution

Dear Friends,
Can you please suggest modifications to convert the following 2D DWT convolution, written with the PyTorch library, into a 3D DWT convolution? Here w holds the db1 wavelet filters, and the wt function performs the 2D wavelet convolution on the image tensor vimg. Reference GitHub repository: Wavelet-U-net-Dehazing/wavelet.py at master · dectrfov/Wavelet-U-net-Dehazing · GitHub


import pywt
import torch
from torch.autograd import Variable  # deprecated in current PyTorch, kept from the original snippet

# db1 (Haar) filters; the decomposition filters are reversed because
# PyTorch's conv2d actually computes a cross-correlation
w = pywt.Wavelet('db1')

dec_hi = torch.Tensor(w.dec_hi[::-1])
dec_lo = torch.Tensor(w.dec_lo[::-1])
rec_hi = torch.Tensor(w.rec_hi)
rec_lo = torch.Tensor(w.rec_lo)

# Outer products of the two 1D filters give the four separable 2D
# subband filters (LL, LH, HL, HH), stacked into shape (4, 2, 2)
filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1)/2.0,
                       dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
def wt(vimg):
    padded = vimg
    # each input channel produces 4 subband channels at half resolution
    res = torch.zeros(vimg.shape[0], 4*vimg.shape[1], int(vimg.shape[2]/2), int(vimg.shape[3]/2))
    res = res.cuda()
    for i in range(padded.shape[1]):
        res[:,4*i:4*i+4] = torch.nn.functional.conv2d(padded[:,i:i+1], Variable(filters[:,None].cuda(),requires_grad=True),stride=2)
        res[:,4*i+1:4*i+4] = (res[:,4*i+1:4*i+4]+1)/2.0
    return res

Everywhere two things appear, make three things appear: where 2D indexing happens, you'd want 3D; where 3D indexing happens, you'd want 4D, and so on.

This sounds like someone's homework problem, and the code is fairly concise, so I'm inclined to leave these hints and let you figure it out.
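
To make the hint concrete for the filter construction: where the 2D code builds each subband filter as an outer product of two 1D filters (2^2 = 4 subbands), the 3D version takes an outer product of three (2^3 = 8 subbands). A minimal sketch, reusing the db1 setup from above (the names ll and lll are illustrative, not from the original code):

import pywt
import torch

w = pywt.Wavelet('db1')
dec_lo = torch.Tensor(w.dec_lo[::-1])

# 2D: a subband filter is an outer product of two 1D filters, e.g. LL
ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)  # shape (2, 2)

# 3D: the analogous subband is an outer product of three 1D filters, e.g. LLL
lll = dec_lo[:, None, None] * dec_lo[None, :, None] * dec_lo[None, None, :]  # shape (2, 2, 2)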

Thank you @smth for your valuable hints. I tried to use the following code:

res = torch.nn.functional.conv3d(padded, Variable(filters[None,:].cuda(),requires_grad=True),stride=(2,2,2)) #[:,None]

But I am getting the following error message:

RuntimeError: expected stride to be a single integer value or a list of 2 values to match the convolution dimensions, but got stride=[2, 2, 2]
Here res is a zero tensor of shape [4099, 120, 12, 12], padded is the training image tensor of shape [4099, 30, 25, 25], and filters holds the db1 DWT coefficients with shape [4, 2, 2].
I would also like to note that this is research work on classifying remote sensing images using wavelet convolution.

Here res is a zero tensor of shape [4099, 120, 12, 12], padded is the training image tensor of shape [4099, 30, 25, 25], and filters holds the db1 DWT coefficients with shape [4, 2, 2].

These shapes look wrong, as conv3d expects a 5D tensor (batched) or a 4D tensor (unbatched) in the shape [batch_size, channels, depth, height, width] or [channels, depth, height, width], respectively. Since your call passed a 4D input and a 4D weight, it was effectively treated as a 2D convolution, which is why the 3-value stride was rejected.
Also, the filters should have 5 dimensions, in the shape [out_channels, in_channels, depth, height, width].
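
For the shapes in this thread, one possible reading (an assumption on my part: the 30 spectral bands act as the depth axis and each cube has a single channel) would be:

import torch

padded = torch.randn(4099, 30, 25, 25)  # stand-in for the 4D training tensor above
vol = padded.unsqueeze(1)               # [4099, 1, 30, 25, 25]: batch, channels, depth, height, width
weight = torch.randn(8, 1, 2, 2, 2)     # [out_channels, in_channels, kd, kh, kw]; placeholder values
res = torch.nn.functional.conv3d(vol, weight, stride=(2, 2, 2))
print(res.shape)                        # torch.Size([4099, 8, 15, 12, 12])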

Thank you @ptrblck. Can you please help me by providing a code snippet for conv3d and the filters for my example? Thank you once again.

Sure, here is an example:

import torch
import torch.nn as nn

device = 'cpu'
batch_size = 4099
channels = 30
depth, height, width = 25, 25, 25
input = torch.randn(batch_size, channels, depth, height, width, device=device)

out_channels = 120
kd, kh, kw = 4, 2, 2  # kernel size
weight = nn.Parameter(torch.randn(out_channels, channels, kd, kh, kw, device=device))

res = torch.nn.functional.conv3d(input, weight, stride=(2,2,2))
print(res.shape)
# torch.Size([4099, 120, 11, 12, 12])
# per dim: floor((in - kernel) / stride) + 1, e.g. depth (25-4)//2+1 = 11, height/width (25-2)//2+1 = 12

@ptrblck, you have mentioned the following:

depth, height, width = 25, 25, 25

But there are 4099 training image cubes with height = 25, width = 25, and depth = 30 (the depth corresponds to the spectral channels). The code you suggested may need some modification. Please respond.

You can freely change the code to your needs. If the depth is wrong, just set it to the right value:

batch_size = 4099
channels = ?
depth, height, width = 30, 25, 25
input = torch.randn(batch_size, channels, depth, height, width, device=device)

out_channels = 120
kd, kh, kw = ?, 2, 2  # kernel size
weight = nn.Parameter(torch.randn(out_channels, channels, kd, kh, kw, device=device))

res = torch.nn.functional.conv3d(input, weight, stride=(2,2,2))
print(res.shape)

In the 2D DWT case, the (2, 2) filter values are duplicated to 4 and converted to shape (4, 1, 2, 2). What would the kernel values be when I change it to a 3D filter? Can you please suggest the changes?
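
Following the same pattern as the 2D case, the 3D analogue would be 8 subband filters of shape (8, 1, 2, 2, 2) rather than 4 filters of shape (4, 1, 2, 2), with channels = 1 if the 30 spectral bands are used as the depth axis. A sketch of a possible 3D counterpart of wt (the name wt3d and the single-channel assumption are mine, not a confirmed answer from the thread):

import pywt
import torch
import torch.nn.functional as F

w = pywt.Wavelet('db1')
dec_lo = torch.Tensor(w.dec_lo[::-1])
dec_hi = torch.Tensor(w.dec_hi[::-1])

# outer products along depth, height, and width give the 8 separable
# 3D subband filters: LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH
filters = torch.stack([
    fd[:, None, None] * fh[None, :, None] * fw[None, None, :]
    for fd in (dec_lo, dec_hi)
    for fh in (dec_lo, dec_hi)
    for fw in (dec_lo, dec_hi)
], dim=0)  # shape (8, 2, 2, 2)

def wt3d(vol):
    # vol: [batch, 1, depth, height, width]; filters[:, None] has shape
    # (8, 1, 2, 2, 2) = [out_channels, in_channels, kd, kh, kw]
    return F.conv3d(vol, filters[:, None].to(vol.device), stride=2)

x = torch.randn(4099, 1, 30, 25, 25)
print(wt3d(x).shape)  # torch.Size([4099, 8, 15, 12, 12])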