I am picking a default device at the top of the script. Here is the code snippet:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch import Tensor
import math
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
class MyLinearFunc(Function):
    @staticmethod
    def forward(ctx, input, weight, weight_fb, bias=None):
        ctx.save_for_backward(input, weight, weight_fb, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, weight_fb, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_weight_fb = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Route the input gradient through the fixed feedback weight
            # instead of weight, i.e. a feedback-alignment-style backward.
            grad_input = grad_output.mm(weight_fb)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0)
        # weight_fb is fixed, so grad_weight_fb deliberately stays None.
        return grad_input, grad_weight, grad_weight_fb, grad_bias
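To convince myself that the backward really routes the input gradient through the feedback matrix rather than through weight.t(), I run this small scratch check (the shapes and the names x, W, B are just for illustration, not part of the model):

# Scratch check: grad_input should equal grad_output @ B (the feedback path).
x = torch.randn(4, 3, requires_grad=True)
W = torch.randn(5, 3, requires_grad=True)
B = torch.randn(5, 3)  # fixed feedback matrix, never updated
out = MyLinearFunc.apply(x, W, B, None)      # shape (4, 5)
out.backward(torch.ones_like(out))
assert torch.allclose(x.grad, torch.ones(4, 5).mm(B))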
class MyLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Fixed feedback weight. Register it as a buffer under the same name
        # the forward pass uses, so model.to(device) moves the very tensor
        # MyLinearFunc receives. (Assigning a plain tensor attribute and
        # registering it under a different buffer name leaves self.weight_fb
        # behind on the CPU after .to(device).)
        self.register_buffer('weight_fb', torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight_fb, a=math.sqrt(5))  # feedback weight
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return MyLinearFunc.apply(input, self.weight, self.weight_fb, self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
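Incidentally, here is a quick scratch check (mine, not part of the model) to confirm that the feedback weight now shows up as a buffer and follows the module to the device:

layer = MyLinear(3, 5).to(device)
print([name for name, _ in layer.named_buffers()])  # ['weight_fb']
print(layer.weight_fb.device)                       # same device as the module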
N, D_in, H, D_out = 64, 1000, 100, 10
model = nn.Sequential(
    MyLinear(D_in, H),
    MyLinear(H, D_out)
)
# Move the model once, before building the optimizer, rather than inside
# the training loop.
model.to(device=device)
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 0.1
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(3):
    x = torch.randn(N, D_in, device=device)
    y = torch.randn(N, D_out, device=device)
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)
    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
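Finally, a last scratch check after the loop (again just mine): the forward weights should have received gradients, while the feedback weight is a buffer with no gradient that Adam never touches:

print(model[0].weight.grad is not None)  # True: updated by the optimizer
print(model[0].weight_fb.requires_grad)  # False: fixed feedback buffer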
Thanks a lot!