Yes definitely. Gonna try to only copy in the relevant stuff for brevity.
Definition of complex layers.
root/complex/complexLayers.py:
import torch
from torch.nn import Module, Parameter, init, Conv2d, Linear, BatchNorm1d, BatchNorm2d, LayerNorm, ConvTranspose2d

def apply_complex(fr, fi, input):
    # Complex multiply built from two real-valued modules:
    # (fr + j*fi)(x_r + j*x_i) = (fr(x_r) - fi(x_i)) + j*(fr(x_i) + fi(x_r))
    return (fr(input.real) - fi(input.imag)).type(torch.complex64) \
        + 1j * (fr(input.imag) + fi(input.real)).type(torch.complex64)

class ComplexLinear(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(ComplexLinear, self).__init__()
        # One real Linear for the real part of the weight, one for the imaginary part
        self.fc_r = Linear(in_features, out_features, bias=bias)
        self.fc_i = Linear(in_features, out_features, bias=bias)

    def forward(self, input):
        return apply_complex(self.fc_r, self.fc_i, input)
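For reference, here's a quick sanity check (not from the repo, just a minimal sketch to convince yourself) that ComplexLinear really performs a complex matmul. I use bias=False because, if I'm reading apply_complex right, with bias=True both fc_r and fc_i add their biases along both paths, so the effective complex bias comes out as (b_r - b_i) + j*(b_r + b_i) rather than b_r + j*b_i:
import torch
from complex.complexLayers import ComplexLinear

torch.manual_seed(0)
fc = ComplexLinear(4, 3, bias=False)
x = torch.randn(2, 4, dtype=torch.complex64)

# Rebuild the equivalent single complex weight matrix
W = fc.fc_r.weight + 1j * fc.fc_i.weight

# apply_complex should agree with a plain complex matmul
assert torch.allclose(fc(x), x @ W.T, atol=1e-5)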
Here’s where our model lives. The model holds the Encoder, which holds the TransformerBlocks, each of which holds a Multi-Headed Self-Attention (MSA) module, which in turn holds the AttentionHeads. The ComplexLinear layers live in the AttentionHead, and the complex layers are imported from another folder in the root directory.
root/models/TransUNet.py:
import torch
import torch.nn as nn
from collections import OrderedDict
from complex.complexLayers import ComplexConv2d, ComplexLinear, ComplexReLU, ComplexBatchNorm1d, ComplexDropout, NaiveComplexLayerNorm
# NOTE: `config` (imported elsewhere, not shown) supplies num_transformers, num_heads,
# encoding_size, attention_size, attention_bias, dropout_rate, and norm_eps

class model(nn.Module):
    def __init__(self, params):
        super(model, self).__init__()
        self.encoder = Encoder(params)
        self.reshaper = Reshape()   # Reshape/Decoder/Embedding/MLP omitted for brevity
        self.decoder = Decoder()

    def forward(self, x):
        # Converting x to complex: view_as_complex expects a real tensor whose
        # last dimension is 2 (the real/imag pairs) and contiguous in that dim
        x = torch.view_as_complex(x)
        x = x.permute(1, 0, 2, 3)
        x = self.encoder(x)
        x = self.reshaper(x)
        x = self.decoder(x)
        # Converting output back to real/imag pairs
        x = torch.view_as_real(x)
        return x
class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.embedding = Embedding(params=params)
        self.transformers = nn.Sequential(OrderedDict([("Block " + str(i), TransformerBlock(params)) for i in range(config.num_transformers)]))

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformers(x)
        return x
class TransformerBlock(nn.Module):
    def __init__(self, params):
        super(TransformerBlock, self).__init__()
        self.attn_norm = NaiveComplexLayerNorm((params["num_patches"], config.encoding_size), eps=config.norm_eps)
        self.attn = MSA()
        self.ffn_norm = NaiveComplexLayerNorm((params["num_patches"], config.encoding_size), eps=config.norm_eps)
        self.ffn = MLP()

    def forward(self, x):
        # Pre-norm transformer block: LayerNorm -> sublayer -> residual add
        h = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x
class MSA(nn.Module):
    def __init__(self):
        super(MSA, self).__init__()
        # nn.ModuleList (not a plain Python list) so the heads' parameters are
        # registered with the module and picked up by .parameters() and .to()
        self.heads = nn.ModuleList([AttentionHead(2, 1) for _ in range(config.num_heads)])
        self.w = ComplexLinear(config.attention_size * config.num_heads, config.encoding_size, bias=config.attention_bias)
        self.dropout = ComplexDropout(config.dropout_rate)

    def forward(self, x):
        # Each head outputs attention_size features; concatenate along the last
        # dim so the total matches self.w's in_features
        all_head_size = config.attention_size * config.num_heads
        multi_head_shape = list(x.shape)
        multi_head_shape[-1] = all_head_size
        # Allocate on x's device so this also works on GPU
        multi_head = torch.zeros(multi_head_shape, dtype=torch.complex64, device=x.device)
        for i, head in enumerate(self.heads):
            multi_head[:, :, :, (i * config.attention_size):((i + 1) * config.attention_size)] = head(x)
        x = self.w(multi_head)
        x = self.dropout(x)
        return x
class AttentionHead(nn.Module):
    def __init__(self, in_channels=2, out_channels=1):
        super(AttentionHead, self).__init__()
        self.keys = ComplexLinear(config.encoding_size, config.attention_size, bias=config.attention_bias)
        self.queries = ComplexLinear(config.encoding_size, config.attention_size, bias=config.attention_bias)
        self.values = ComplexLinear(config.encoding_size, config.attention_size, bias=config.attention_bias)
        # 2->1 conv that maps the (real, imag) score channels down to a single
        # real-valued attention map
        self.complex_map = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)
        # Scaled dot-product scores, still complex at this point
        scores = complex_matmul(queries, keys.transpose(-1, -2))
        scores /= config.attention_size ** 0.5
        # Split into (real, imag) channels and collapse them to one real channel
        scores = torch.view_as_real(scores)
        scores = scores[:, 0, :, :]
        scores = scores.permute(0, 3, 1, 2)
        scores = self.complex_map(scores)
        scores = torch.softmax(scores, dim=-1)
        scores = self.dropout(scores)
        # Promote back to complex so we can matmul with the complex values
        scores = torch.complex(scores, torch.zeros_like(scores))
        return complex_matmul(scores, values)
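One thing that isn't shown above is complex_matmul (it lives elsewhere in the repo). For completeness, here's a minimal sketch of what I'm assuming it does, i.e. the standard complex product built from four real matmuls. On recent PyTorch versions torch.matmul handles complex64 tensors directly, so it could also just be a thin wrapper around that:
import torch

def complex_matmul(a, b):
    # (a_r + j*a_i)(b_r + j*b_i) = (a_r b_r - a_i b_i) + j*(a_r b_i + a_i b_r)
    real = torch.matmul(a.real, b.real) - torch.matmul(a.imag, b.imag)
    imag = torch.matmul(a.real, b.imag) + torch.matmul(a.imag, b.real)
    return torch.complex(real, imag)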
Thanks so much for your time, man. It helps a ton!!