Here is my custom implementation of a CIFG (coupled input-forget gate LSTM) in PyTorch:
import math
import torch
from torch import nn
class CIFGCell(nn.Module):
    """Single-step LSTM-style cell with a coupled input/forget gate (CIFG).

    Only three gate projections are stored (input, cell candidate, output);
    the forget gate is derived as ``1 - input_gate`` instead of being learned
    separately.
    """

    def __init__(self, inp_dim, hidden_dim):
        """Create the gate weights and biases and initialise them.

        Args:
            inp_dim: Number of features in each input row.
            hidden_dim: Number of features in the hidden and cell states.
        """
        super().__init__()
        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim

        # The three gates are packed along dim 0: [input | cell | output].
        gate_rows = 3 * hidden_dim
        self.w_ih = nn.Parameter(torch.empty((gate_rows, inp_dim)))
        self.w_hh = nn.Parameter(torch.empty((gate_rows, hidden_dim)))
        self.b_ih = nn.Parameter(torch.empty(gate_rows))
        self.b_hh = nn.Parameter(torch.empty(gate_rows))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Fill every parameter with U(-b, b), where b = 1 / sqrt(hidden_dim).

        When ``hidden_dim`` is not positive the bound is 0.
        """
        if self.hidden_dim > 0:
            bound = 1.0 / math.sqrt(self.hidden_dim)
        else:
            bound = 0
        for param in self.parameters():
            nn.init.uniform_(param, -bound, bound)

    def forward(self, x, hidden=None):
        """Advance the cell by one step.

        Args:
            x: 2-D tensor of shape ``(batch, inp_dim)``.
            hidden: Optional ``(h, c)`` pair, each ``(batch, hidden_dim)``.
                When omitted, both states start as zeros with the dtype and
                device of ``x``.

        Returns:
            Tuple ``(h_next, c_next)``, each of shape ``(batch, hidden_dim)``.
        """
        if hidden is not None:
            h_prev, c_prev = hidden
        else:
            state_shape = (x.size(0), self.hidden_dim)
            h_prev = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
            c_prev = torch.zeros(state_shape, dtype=x.dtype, device=x.device)

        # Pre-activations for all three gates, computed in one shot.
        input_proj = torch.mm(x, self.w_ih.t())
        hidden_proj = torch.mm(h_prev, self.w_hh.t())
        pre_act = input_proj + hidden_proj + self.b_ih + self.b_hh
        in_pre, cand_pre, out_pre = pre_act.chunk(3, 1)

        write_gate = torch.sigmoid(in_pre)
        # Coupling: whatever is not written from the candidate is kept.
        keep_gate = 1 - write_gate
        candidate = torch.tanh(cand_pre)
        read_gate = torch.sigmoid(out_pre)

        c_next = (keep_gate * c_prev) + (write_gate * candidate)
        h_next = read_gate * torch.tanh(c_next)
        return h_next, c_next
class CIFG(nn.Module):
    """Stack of ``CIFGCell`` layers unrolled over a time-major sequence.

    The first layer consumes the input features; every following layer
    consumes the hidden output of the layer below it.
    """

    def __init__(self, inp_dim, hidden_dim, n_layers=1):
        """Build the layer stack.

        Args:
            inp_dim: Number of input features per timestep.
            hidden_dim: Number of hidden/cell features in every layer.
            n_layers: Number of stacked cells.
        """
        super().__init__()
        self.inp_dim = inp_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.layers = nn.ModuleList([
            CIFGCell(inp_dim, hidden_dim) if i == 0 else CIFGCell(hidden_dim, hidden_dim)
            for i in range(n_layers)
        ])

    def forward(self, x, hidden=None):
        """Run the whole sequence through the stack.

        Args:
            x: Tensor of shape ``(seq_len, batch, inp_dim)``; dim 0 is
                iterated as time and dim 1 is taken as the batch size.
            hidden: Optional ``(hx, cx)`` pair, each of shape
                ``(n_layers, batch, hidden_dim)``. Defaults to zeros with
                the dtype and device of ``x``.

        Returns:
            ``(outputs, (hx, cx))`` where ``outputs`` holds the top layer's
            hidden state for every timestep, shape
            ``(seq_len, batch, hidden_dim)``, and ``(hx, cx)`` are the states
            of all layers after the last timestep. For a zero-length
            sequence ``outputs`` is empty along dim 0 and the states are
            returned unchanged.
        """
        if hidden is None:
            state_shape = (self.n_layers, x.size(1), self.hidden_dim)
            hx = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
            cx = torch.zeros(state_shape, dtype=x.dtype, device=x.device)
        else:
            hx, cx = hidden

        outputs = []
        for ts in range(x.size(0)):
            hx, cx = self.forward_timestep(x[ts], hx, cx)
            # The sequence output is the hidden state of the top layer.
            outputs.append(hx[-1])

        if not outputs:
            # torch.stack rejects an empty list, so build the empty result
            # explicitly with the shape a per-step hx[-1] would have had.
            return hx.new_empty((0, *hx.shape[1:])), (hx, cx)
        return torch.stack(outputs), (hx, cx)

    def forward_timestep(self, x, hx, cx):
        """Push one timestep through every layer, bottom to top.

        Args:
            x: Input for this timestep, shape ``(batch, inp_dim)``.
            hx: Hidden states of all layers, ``(n_layers, batch, hidden_dim)``.
            cx: Cell states of all layers, ``(n_layers, batch, hidden_dim)``.

        Returns:
            ``(new_hx, new_cx)``, the updated hidden and cell states of all
            layers, each stacked along a new leading layer dimension.
        """
        # One (1, batch, hidden_dim) slice per layer.
        h_per_layer = hx.chunk(self.n_layers, 0)
        c_per_layer = cx.chunk(self.n_layers, 0)

        layer_input = x
        new_h = []
        new_c = []
        for i, layer in enumerate(self.layers):
            prev_state = (h_per_layer[i].squeeze(0), c_per_layer[i].squeeze(0))
            h_i, c_i = layer(layer_input, prev_state)
            new_h.append(h_i)
            new_c.append(c_i)
            # Each layer feeds its hidden output to the layer above.
            layer_input = h_i
        return torch.stack(new_h), torch.stack(new_c)