import time
import torch
import torch.nn as nn
import numpy as np


class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.last = nn.Sequential(
            nn.Linear(11, 2),
            nn.LeakyReLU(),
            nn.Linear(2, 2),
            nn.LeakyReLU(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        return self.last(x)


class B(nn.Module):
    def __init__(self):
        super(B, self).__init__()
        self.last = nn.Sequential(
            nn.Linear(11, 2),
            nn.LeakyReLU(),
            nn.Linear(2, 2),
            nn.LeakyReLU(),
            nn.Linear(2, 1),
        )

    def forward(self, x):
        x = self.last(x)
        return x


from torch.autograd import Function


class RevGradF(Function):
    @staticmethod
    def forward(ctx, input_, alpha_):
        ctx.save_for_backward(input_, alpha_)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        grad_input = None
        _, alpha_ = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output * alpha_
        return grad_input, None


revgrad = RevGradF.apply


class RevGrad(nn.Module):
    def __init__(self, alpha=1., *args, **kwargs):
        """
        A gradient reversal layer.
        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """
        super().__init__(*args, **kwargs)
        self._alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, input_):
        return revgrad(input_, self._alpha)


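# Quick sanity-check sketch of the reversal layer: RevGrad is the identity in
# the forward pass, while the backward pass multiplies the incoming gradient
# by -alpha (the name _check is only used for this check).
_check = torch.ones(3, requires_grad=True)
RevGrad(alpha=2.0)(_check).sum().backward()
print(_check.grad)  # should print tensor([-2., -2., -2.])

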
class TestModule:
    def __init__(self):
        self.a = A()
        self.b = B()
        self.revgrad = RevGrad()
        self.b = nn.Sequential(self.b, self.revgrad)
        # self.optimizer = torch.optim.Adam([{"params": self.a.parameters()}, {"params": self.b.parameters()}])
        self.a_optimizer = torch.optim.Adam(self.a.parameters())
        self.b_optimizer = torch.optim.Adam(self.b.parameters())


tm = TestModule()
for _ in range(10000):
    input_array = np.array([1 for _ in range(11)])  # renamed so it does not shadow the built-in input()
    input_tensor = torch.from_numpy(input_array).float()
    result_a = tm.a(input_tensor)
    result_b = tm.b(input_tensor)
    loss_a = (result_a - 1) ** 2
    loss_b = (result_b - 1) ** 2
    loss = (loss_a - loss_b) ** 2
    tm.a_optimizer.zero_grad()
    tm.b_optimizer.zero_grad()
    # for e in tm.a.parameters():
    #     e.requires_grad = True
    # for e in tm.b.parameters():
    #     e.requires_grad = False
    loss.backward()
    # a_loss.backward()
    # print("a_backward")
    # print(f"loss : {loss}")
    # print(tm.a.last[0].weight.grad)
    # print(tm.b[0].last[0].weight.grad)
    # for e in tm.a.parameters():
    #     e.requires_grad = False
    # for e in tm.b.parameters():
    #     e.requires_grad = True
    # b_loss.backward()
    # print("b_backward")
    # print(tm.a.last[0].weight.grad)
    # print(tm.b[0].last[0].weight.grad)
    tm.a_optimizer.step()
    tm.b_optimizer.step()
    print(f"a : {result_a}")
    print(f"b : {result_b}")
    # time.sleep(5)

I heard from other places that I should use an autograd.Function for this. Is that appropriate?
It seems to work, but I am wondering whether this is the proper way to do it in PyTorch.
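For comparison, here is a hook-based sketch of the same reversal (HookRevGrad is just a name I made up, not from any library), using Tensor.register_hook instead of a custom autograd.Function; the hook replaces the incoming gradient with -alpha * grad during the backward pass.

import torch
import torch.nn as nn

class HookRevGrad(nn.Module):
    # Hook-based gradient reversal: identity in the forward pass; in the
    # backward pass the registered hook returns -alpha * grad, and that
    # gradient is what propagates to the layers that produced x.
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, x):
        if x.requires_grad:
            x.register_hook(lambda grad: -self.alpha * grad)
        return x

For a fixed alpha, both versions should give the same gradients to the layers placed before the reversal; the autograd.Function version just bakes the reversed backward pass into the op itself instead of attaching a hook at call time.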