Hi, everyone
I am trying to train a model contained customized layer without learnable parameters and the model throws no grad_fn error.
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
# custom layer
class Filter(nn.Module):
def __init__(self, filter_size=10, filter_step=1):
super(Filter, self).__init__()
self.filter_size = filter_size
self.filter_step = filter_step
def Output_filter(self, x):
frame_indx = 0
while 1:
if frame_indx + self.filter_size <= len(x):
if sum(x[frame_indx: frame_indx + self.filter_size]) >= self.filter_size / 2:
x[frame_indx: frame_indx + self.filter_size] = 1
else:
if sum(x[frame_indx:]) >= len(x[frame_indx:]) / 2:
x[frame_indx:] = 1
break
frame_indx = frame_indx + self.filter_step
return x
def forward(self, x):
x = F.relu(self.Output_filter(x))
return x
# model
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(7,2)
self.filter = Filter(3,1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = torch.argmax(x, 1).float()
x = self.filter(x)
return x
if __name__ == "__main__":
net = Net()
optimizer = torch.optim.Adam(net.parameters())
train_x = torch.rand((10,7),dtype=torch.float)
train_y = torch.tensor([1,1,1,1,0,1,0,0,1,0], dtype=torch.float)
train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=False, num_workers=2)
loss_func = nn.BCELoss()
net.train()
for step, (b_x, b_y) in enumerate(train_loader, 1):
pred_train = net(b_x)
loss_ = loss_func(pred_train, b_y)
optimizer.zero_grad()
loss_.backward()
optimizer.step()
My guess is the torch.argmax(x, 1) operation is NOT differentiable, so I append the requires_grad_(True), like torch.argmax(x, 1).float().requires_grad_(True), but this time the model throws ‘leaf variable has been moved into the graph interior’ error. In pycharm debug mode, x.is_leaf is true after torch.argmax operation.
My question is there any other differentiable function can be used to replace the torch.argmax or is there any other ways to make my code work
Thanks any help in advance