class CWT(nn.Module):
    """Continuous wavelet transform computed with a 2D convolution.

    One wavelet per entry of ``widths`` is generated at construction time and
    stored in ``self.wavelet_bank``; ``forward`` cross-correlates the input
    with that bank.
    """

    def __init__(
        self,
        widths,
        wavelet="ricker",
        channels=1,
        filter_len=2000,
    ):
        """PyTorch implementation of a continuous wavelet transform.

        Args:
            widths (iterable): The wavelet scales to use, e.g. np.arange(1, 33)
            wavelet (str, optional): Name of the method that generates the
                wavelet: "ricker", "morlet" or "cmorlet" (complex Morlet).
                Defaults to "ricker".
            channels (int, optional): Number of channels in the input.
                Defaults to 1.
            filter_len (int, optional): Size of the wavelet filter bank. Set to
                the number of samples but can be smaller to save memory.
                Defaults to 2000.

        Raises:
            AttributeError: If ``wavelet`` does not name an attribute of this
                object.
        """
        super().__init__()
        self.widths = widths
        # The generator is looked up by name, so subclasses can add wavelets.
        self.wavelet = getattr(self, wavelet)
        self.filter_len = filter_len
        self.channels = channels
        self.wavelet_bank = self._build_wavelet_bank()

    @staticmethod
    def _sample_positions(points):
        """Return ``points`` unit-spaced sample positions centred on zero."""
        return torch.arange(0, points) - (points - 1.0) / 2

    @staticmethod
    def _pad_same(x, kernel_size):
        """Zero-pad the last two axes of ``x`` for a size-preserving convolution.

        Explicit padding is used instead of ``conv2d(..., padding="same")``
        because string padding is only accepted by newer PyTorch releases.

        Args:
            x (torch.Tensor): 4-D input of the convolution.
            kernel_size (sequence of int): Kernel (height, width).

        Returns:
            torch.Tensor: ``x`` grown by ``kernel - 1`` along each of the last
            two axes.
        """
        pad_h, pad_w = (k - 1 for k in kernel_size)
        # F.pad lists the last axis first. When kernel - 1 is odd the extra
        # element goes after the data, which is where padding="same" puts it.
        return nn.functional.pad(
            x,
            (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
        )

    def ricker(self, points, a):
        """Return a Ricker wavelet sampled at ``points`` positions.

        Args:
            points (int): Number of samples in the wavelet.
            a (number): Width parameter of the wavelet.

        Returns:
            torch.Tensor: Real 1-D tensor of length ``points``.
        """
        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/signal/wavelets.py#L262-L306
        amplitude = 2 / (np.sqrt(3 * a) * (np.pi ** 0.25))
        wsq = a ** 2
        xsq = self._sample_positions(points) ** 2
        mod = 1 - xsq / wsq
        gauss = torch.exp(-xsq / (2 * wsq))
        return amplitude * mod * gauss

    def morlet(self, points, s):
        """Return a real Morlet wavelet sampled at ``points`` positions.

        Args:
            points (int): Number of samples in the wavelet.
            s (number): Scale by which the sample positions are divided.

        Returns:
            torch.Tensor: Real 1-D tensor of length ``points``.
        """
        x = self._sample_positions(points) / s
        # https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#morlet-wavelet
        wavelet = torch.exp(-(x ** 2.0) / 2.0) * torch.cos(5.0 * x)
        return np.sqrt(1 / s) * wavelet

    def cmorlet(self, points, s, wavelet_width=1, center_freq=1):
        """Return a complex Morlet wavelet sampled at ``points`` positions.

        Args:
            points (int): Number of samples in the wavelet.
            s (number): Scale by which the sample positions are divided.
            wavelet_width (number, optional): Divisor of the squared position
                inside the Gaussian envelope. Defaults to 1.
            center_freq (number, optional): Frequency of the complex
                exponential. Defaults to 1.

        Returns:
            torch.Tensor: Complex 1-D tensor of length ``points``.
        """
        # https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#complex-morlet-wavelets
        x = self._sample_positions(points) / s
        norm_constant = np.sqrt(np.pi * wavelet_width)
        envelope = torch.exp(-(x ** 2) / wavelet_width) / norm_constant
        return envelope * torch.exp(1j * 2 * np.pi * center_freq * x)

    def _build_wavelet_bank(self):
        """This function builds a 2D wavelet filter using wavelets at different scales

        Each wavelet is time-reversed and conjugated, then the same row is
        repeated ``channels`` times along the third axis.

        Returns:
            tensor: Tensor of shape (num_widths, 1, channels, filter_len)
        """
        filters = torch.stack(
            [
                torch.conj(torch.flip(self.wavelet(self.filter_len, w), [-1]))
                for w in self.widths
            ]
        )
        # (num_widths, filter_len) -> (num_widths, 1, 1, filter_len)
        filters = filters[:, None, None, :]
        return filters.repeat(1, 1, self.channels, 1)

    def forward(self, x):
        """Compute CWT arrays from a batch of multi-channel inputs

        NOTE(review): the kernel spans ``channels`` rows and the input is also
        padded along the channel axis, so for ``channels > 1`` every output
        channel is the sum of the transforms of neighbouring input channels
        rather than of a single channel. This is kept as is to preserve
        existing results; confirm whether it is intended.

        Args:
            x (torch.tensor): Tensor of shape (batch_size, channels, time)

        Returns:
            torch.tensor: Tensor of shape (batch_size, channels, widths, time).
            It is complex when the wavelet bank is complex.
        """
        x = x.unsqueeze(1)  # (batch, 1, channels, time)

        # Cache only the device transfer. The dtype cast is done per call so
        # that a low-precision batch never degrades the stored filters.
        self.wavelet_bank = self.wavelet_bank.to(device=x.device)
        bank = self.wavelet_bank

        x = self._pad_same(x, bank.shape[-2:])

        if bank.is_complex():
            # conv2d takes real weights here, so convolve both parts separately.
            out_real = nn.functional.conv2d(x, bank.real.to(dtype=x.dtype))
            out_imag = nn.functional.conv2d(x, bank.imag.to(dtype=x.dtype))
            return torch.complex(
                torch.transpose(out_real, 1, 2),
                torch.transpose(out_imag, 1, 2),
            )

        output = nn.functional.conv2d(x, bank.to(dtype=x.dtype))
        # (batch, widths, channels, time) -> (batch, channels, widths, time)
        return torch.transpose(output, 1, 2)
==>
# Example: compute and plot the CWT magnitude of one 3-channel recording.
# Alternative configurations tried by the author:
# pycwt = CWT(widths, "cmorlet", 3, 4096)

# pycwt = CWT(widths, "ricker", 3)
x = np.load(train.loc[i, 'file_path'])
# Taper the edges of the recording. The window lives in `signal.windows`;
# newer SciPy releases no longer expose the top-level `signal.tukey` alias.
# Presumably x is (3, 4096), given the view() below -- confirm against the data.
x *= signal.windows.tukey(4096, 0.2)


x = apply_bandpass(x, 30, 480)
x_ten = torch.tensor(x, dtype=torch.float32).view(1, 3, 4096)


widths = np.arange(start=10, stop=90)

pycwt = CWT(widths, "cmorlet", 3, 4096)
# pycwt = CWT(widths, "cmorlet", 3, 2048)
# this ==>>>> (call flagged by the author)
out = pycwt(x_ten)

print(out.shape)
# The complex Morlet output is complex; plot its magnitude.
mag = torch.absolute(out)

plt.imshow(
    mag[0, 0].numpy(),
    aspect="auto",
    vmax=mag[0, 0].max(),
    vmin=mag[0, 0].min(),
);
```
This is the code I am executing.
Blockquote
Explanation:
It works fine in Google Colab, but it doesn't work in a Kaggle kernel.
Why is this? I'm having a hard time finding a good answer.
I also don't understand why it doesn't work when I use a tuple of numbers.