Thank you!
Also one more einsum for fourier transform is kind of not working.
class FourierMMLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.dft_mat_seq = torch.tensor(linalg.dft(512))
self.dft_mat_hidden = torch.tensor(linalg.dft(768))
def forward(self, hidden_states):
hidden_states_complex = hidden_states.type(torch.complex128)
#pre fourier torch.Size([2, 9, 768]) of hidden states
return torch.einsum(
"...ij,...jk,...ni->...nk",
hidden_states_complex,
self.dft_mat_hidden,
self.dft_mat_seq
).real.type(torch.float32)
Traceback (most recent call last):
File "inference.py", line 22, in <module>
obj1.forward(input_ids, token_type_ids)
File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 124, in forward
self.encoder(input_ids, type_ids)
File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 113, in forward
sequence_output = self.encoder(embedding_output)
File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 94, in forward
hidden_states = layer_module(hidden_states)
File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 80, in forward
fft_output = self.fft(hidden_states)
File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 62, in forward
return torch.einsum(
File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 9, 768]->[2, 1, 1, 9, 768] [768, 768]->[1, 1, 768, 1, 768] [512, 512]->[1, 512, 1, 512, 1]
Can you please help?