Dear Friends,
Could you please suggest modifications to convert the following 2D DWT convolution, written with PyTorch, into a 3D DWT convolution? Here w is the db1 wavelet whose filter coefficients are used to build filters, and the wt function performs a 2D wavelet convolution on the image vimg. The reference GitHub repository is Wavelet-U-net-Dehazing/wavelet.py at master · dectrfov/Wavelet-U-net-Dehazing · GitHub

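```python
import pywt
import torch
import torch.nn.functional as F

w = pywt.Wavelet('db1')
# Decomposition filters are reversed because conv2d computes a cross-correlation
dec_hi = torch.Tensor(w.dec_hi[::-1])
dec_lo = torch.Tensor(w.dec_lo[::-1])
# Reconstruction filters (needed for the inverse transform, not used in wt below)
rec_hi = torch.Tensor(w.rec_hi)
rec_lo = torch.Tensor(w.rec_lo)
# Four 2x2 filters built as outer products of the 1D filters: shape [4, 2, 2]
filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1)/2.0,
                       dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                       dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)

def wt(vimg):
    # vimg: [N, C, H, W] -> res: [N, 4*C, H//2, W//2]
    padded = vimg
    res = torch.zeros(vimg.shape[0], 4*vimg.shape[1], vimg.shape[2]//2, vimg.shape[3]//2,
                      device=vimg.device)
    weight = filters[:, None].to(vimg.device)  # [4, 1, 2, 2]
    for i in range(padded.shape[1]):
        # Apply the four sub-band filters to channel i with stride 2
        res[:, 4*i:4*i+4] = F.conv2d(padded[:, i:i+1], weight, stride=2)
        # Map the three detail sub-bands with (x + 1) / 2
        res[:, 4*i+1:4*i+4] = (res[:, 4*i+1:4*i+4] + 1) / 2.0
    return res
```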
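For reference, the per-channel loop in wt can also be written as a single grouped convolution, which is the form that carries over most directly to conv3d. This is a sketch under assumptions, not code from the linked repository: it hardcodes the db1 coefficients (equal to pywt's dec_lo[::-1] and dec_hi[::-1]) so pywt is not needed, and the names wt_loop and wt_grouped are hypothetical.

```python
import torch
import torch.nn.functional as F

# Hardcoded Haar (db1) analysis filters (assumption: same values as pywt's reversed dec_lo/dec_hi)
s = 2 ** -0.5
lo = torch.tensor([s, s])
hi = torch.tensor([s, -s])

# Same four 2x2 filters as in the question: shape [4, 2, 2]
filters = torch.stack([lo.unsqueeze(0) * lo.unsqueeze(1) / 2.0,
                       lo.unsqueeze(0) * hi.unsqueeze(1),
                       hi.unsqueeze(0) * lo.unsqueeze(1),
                       hi.unsqueeze(0) * hi.unsqueeze(1)], dim=0)

def wt_loop(vimg):
    # Per-channel loop, as in the question (device-agnostic, returns res)
    n, c, h, w = vimg.shape
    res = torch.zeros(n, 4 * c, h // 2, w // 2, dtype=vimg.dtype, device=vimg.device)
    weight = filters[:, None].to(vimg)  # [4, 1, 2, 2]
    for i in range(c):
        res[:, 4*i:4*i+4] = F.conv2d(vimg[:, i:i+1], weight, stride=2)
        res[:, 4*i+1:4*i+4] = (res[:, 4*i+1:4*i+4] + 1) / 2.0
    return res

def wt_grouped(vimg):
    # One grouped conv: each input channel gets its own copy of the 4 filters
    n, c, h, w = vimg.shape
    weight = filters[:, None].repeat(c, 1, 1, 1).to(vimg)  # [4*c, 1, 2, 2]
    res = F.conv2d(vimg, weight, stride=2, groups=c)        # [n, 4*c, h//2, w//2]
    res = res.view(n, c, 4, h // 2, w // 2)
    # Detail bands (positions 1..3 in each group of 4) get the same (x + 1) / 2 mapping
    detail = (res[:, :, 1:] + 1) / 2.0
    return torch.cat([res[:, :, :1], detail], dim=2).view(n, 4 * c, h // 2, w // 2)
```

With groups=C, output channels 4*i to 4*i+3 only see input channel i, which is what slicing padded[:, i:i+1] does in the loop, but without the Python loop.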

Everywhere two things appear, make three appear. Where 2D indexing happens, you'd want 3D indexing; where 3D indexing happens, you'd want 4D, and so on.
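One way to apply that hint to the filters: each 2D filter is an outer product of two 1D filters (lo or hi), so a 3D filter is an outer product of three, giving 2**3 = 8 filters of shape 2x2x2 instead of 4 filters of shape 2x2. A minimal sketch, assuming hardcoded Haar coefficients and a hypothetical helper separable_filter_bank; it leaves out the extra /2.0 on the LL filter and the (x+1)/2 mapping of the detail bands that the question's code applies.

```python
import itertools
import torch

# Hardcoded Haar (db1) analysis filters (assumption: equal to pywt's dec_lo[::-1], dec_hi[::-1])
s = 2 ** -0.5
lo = torch.tensor([s, s])
hi = torch.tensor([s, -s])

def separable_filter_bank(ndim):
    """Return 2**ndim filters of shape [2]*ndim, one per (lo, hi) choice along each axis."""
    bank = []
    for combo in itertools.product((lo, hi), repeat=ndim):
        f = combo[0]
        for g in combo[1:]:
            # Outer product: append a new last axis indexed by g
            f = f.unsqueeze(-1) * g
        bank.append(f)
    return torch.stack(bank, dim=0)

filters_2d = separable_filter_bank(2)  # [4, 2, 2]; first letter of LL/LH/HL/HH = axis 0
filters_3d = separable_filter_bank(3)  # [8, 2, 2, 2]; LLL, LLH, ..., HHH over (depth, height, width)
```

Because lo and hi are orthonormal, the 8 flattened 3D filters are orthonormal too, which is the property that makes them a valid 3D Haar analysis bank.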

This sounds like someone's homework problem, and the code is fairly concise, so I'm inclined to leave these hints and let you figure it out.

Thank you @smth for your valuable hints. I tried to use the following code:

```python
res = torch.nn.functional.conv3d(padded, Variable(filters[None,:].cuda(),requires_grad=True),stride=(2,2,2)) #[:,None]
```

But I get the following error:

```
RuntimeError: expected stride to be a single integer value or a list of 2 values to match the convolution dimensions, but got stride=[2, 2, 2]
```
res is a zero tensor of shape [4099, 120, 12, 12], padded is the training image tensor of shape [4099, 30, 25, 25], and filters holds the db1 DWT coefficients with shape [4, 2, 2].
Regarding your comment: this is research work on classifying remote sensing images using wavelet convolution.


These shapes sound wrong, as conv3d layers expect a 5D tensor (batched) or a 4D tensor (unbatched) with shape [batch_size, channels, depth, height, width] or [channels, depth, height, width], respectively.
Also, the filters should have 5 dimensions: [out_channels, in_channels, depth, height, width].
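To make these shapes concrete, here is a small sketch with random weights (only the shapes matter). The batch size of 4 and the 8 output channels are placeholders, and the unbatched call assumes a recent PyTorch release with no-batch-dim support for convolutions. With padding=0 and dilation=1, each spatial size becomes (n - k) // s + 1.

```python
import torch
import torch.nn.functional as F

# Stand-in batch: 4 cubes, 1 input channel, depth 30, height 25, width 25
x_batched = torch.randn(4, 1, 30, 25, 25)        # [batch, channels, depth, height, width]
weight = torch.randn(8, 1, 2, 2, 2)              # [out_channels, in_channels, kd, kh, kw]

y = F.conv3d(x_batched, weight, stride=2)                # batched: 5D in, 5D out
y_unbatched = F.conv3d(x_batched[0], weight, stride=2)   # unbatched: 4D in, 4D out

def out_size(n, k, s):
    # Output length along one spatial dim for padding=0, dilation=1
    return (n - k) // s + 1

print(y.shape)            # torch.Size([4, 8, 15, 12, 12])
print(y_unbatched.shape)  # torch.Size([8, 15, 12, 12])
```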

But I have 4099 training image cubes with height = 25, width = 25, and depth = 30 (the 30 depth slices are the spectral channels). The code you suggested may need some modification. Please respond.

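```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 4099
channels = ?  # what should this be?
depth, height, width = 30, 25, 25
input = torch.randn(batch_size, channels, depth, height, width, device=device)
```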
In the 2D DWT, the 2x2 filters are combined into four sub-band filters and reshaped to (4, 1, 2, 2).
What would the kernel values be when I change it to a 3D filter?
```python
out_channels = 120
kd, kh, kw = ?, 2, 2  # kernel size
weight = nn.Parameter(torch.randn(out_channels, channels, kd, kh, kw, device=device))
res = torch.nn.functional.conv3d(input, weight, stride=(2, 2, 2))
print(res.shape)
```
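Putting the hints together for the cubes described in this thread: if the 30 bands are treated as depth, the input has a single channel (channels = 1), the Haar kernel is 2x2x2 (kd = 2), and each cube yields 8 sub-bands rather than 4. This is a sketch under assumptions, not a confirmed answer: the db1 coefficients are hardcoded, wt3d is a hypothetical name, and the question's extra LL scaling and (x+1)/2 mapping are omitted but could be added back.

```python
import itertools
import torch
import torch.nn.functional as F

# Hardcoded Haar (db1) analysis filters (assumption: equal to pywt's dec_lo[::-1], dec_hi[::-1])
s = 2 ** -0.5
lo = torch.tensor([s, s])
hi = torch.tensor([s, -s])

# 8 separable 2x2x2 filters: every lo/hi choice along depth, height, width
filters3d = torch.stack([
    torch.einsum('i,j,k->ijk', a, b, c)
    for a, b, c in itertools.product((lo, hi), repeat=3)
])  # [8, 2, 2, 2]

def wt3d(cubes):
    """Hypothetical 3D counterpart of wt().

    cubes: [N, D, H, W], with the spectral bands treated as depth D.
    returns: [N, 8, D//2, H//2, W//2], one output channel per sub-band.
    """
    x = cubes.unsqueeze(1)                   # channels = 1 -> [N, 1, D, H, W]
    weight = filters3d[:, None].to(x)        # [8, 1, 2, 2, 2], i.e. kd = kh = kw = 2
    return F.conv3d(x, weight, stride=(2, 2, 2))

# Usage with 2 stand-in cubes instead of 4099
cubes = torch.randn(2, 30, 25, 25)
res = wt3d(cubes)            # [2, 8, 15, 12, 12]
flat = res.flatten(1, 2)     # [2, 120, 12, 12], if a 4D layout is needed downstream
```

Flattening the sub-band and depth axes happens to give 8 * 15 = 120 channels of 12x12, the same shape as the 2D res, but each channel now stands for a (sub-band, depth position) pair rather than a (band, 2D sub-band) pair. Whether that layout suits the classifier is a design choice.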