OK - I went through the documentation and tried a few things. I'm still not sure why this new version works, but the following works for me (for anyone who runs into the same issue). The backward simply passes the gradient through the round() unchanged - a straight-through estimator:

import torch
from torch import nn
from torch.autograd import Function
from torch.optim import SGD


class BinaryActivation(Function):
    """Round in the forward pass; pass the gradient through unchanged in the
    backward pass (new-style Function with static methods, used via .apply)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # saved input isn't actually needed by this backward
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

class BinaryLayer(Function):
    """The same op written against the legacy Function API (instantiated and
    called directly); depending on your PyTorch version this style may be
    deprecated in favour of the static form above."""

    def forward(self, input):
        return input.round()

    def backward(self, grad_output):
        return grad_output
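
# Not part of the original snippet: a minimal sanity check of the
# straight-through behaviour - forward rounds to {0, 1}, backward treats the
# rounding as identity.
def ste_check():
    x = torch.tensor([0.2, 0.8], requires_grad=True)
    y = BinaryActivation.apply(x)  # forward: tensor([0., 1.])
    y.sum().backward()
    print(x.grad)  # gradient passed straight through: tensor([1., 1.])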

class SkipRNN(nn.Module):
    def __init__(self, c_in=10, c_hidden=10):
        super(SkipRNN, self).__init__()
        self.hidden_layer = nn.Linear(c_in, c_hidden)
        self.gate = nn.Sequential(nn.Linear(c_hidden, 1), nn.Sigmoid())
        self.num_hidden = c_hidden

    def forward(self, x):
        '''x.shape = [batch, time_steps, features]'''
        bn = BinaryActivation.apply
        u_t = torch.zeros((x.size(0), 1)).float()   # update-gate accumulator
        s_t = torch.zeros((x.size(0), self.num_hidden)).float()  # hidden state
        out = torch.zeros((x.size(0), x.size(1), self.num_hidden))
        for t in range(x.size(1)):
            u_t_bin = bn(u_t)  # binary decision: update (1) or copy (0) the state
            s_t = u_t_bin * self.hidden_layer(x[:, t, :]) + (1 - u_t_bin) * s_t
            del_u_t = self.gate(s_t)
            u_t = u_t_bin * del_u_t + (1 - u_t_bin) * (u_t + torch.min(del_u_t, 1 - u_t))
            out[:, t, :] = s_t
        return out
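
# Also my addition, not from the original post: check that gradients actually
# reach the gate parameters through the binarized u_t (only possible because
# of the straight-through backward).
def gate_grad_check():
    model = SkipRNN(10, 10)
    out = model(torch.rand(4, 6, 10))
    out.mean().backward()
    # Typically non-zero once the gate has fired at least once in the sequence.
    print(model.gate[0].weight.grad.abs().sum())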

def basic_check():
    learning_rate = .1
    x = torch.rand((8, 5)).float()
    y = torch.rand((8, 5)).float()
    # Create random Tensors for weights.
    w1 = torch.randn(5, 10, dtype=torch.float, requires_grad=True)
    w2 = torch.randn(10, 5, dtype=torch.float, requires_grad=True)
    for t in range(50):
        # bn = BinaryActivation.apply  # new-style equivalent
        bn = BinaryLayer()
        y_pred = bn(x.mm(w1)).mm(w2)
        loss = (y_pred - y).pow(2).mean()
        loss.backward()
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad
            # Manually zero the gradients after updating weights.
            w1.grad.zero_()
            w2.grad.zero_()

def skip_rnn_check():
    learning_rate = .1
    x = torch.rand((8, 20, 10)).float()
    y = torch.rand((8, 20, 10)).float()
    model = SkipRNN(10, 10)
    optimizer = SGD(model.parameters(), lr=learning_rate)
    for t in range(50):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = (y_pred - y).pow(2).mean()
        loss.backward()
        optimizer.step()

if __name__ == '__main__':
    basic_check()
    skip_rnn_check()