Hi,
I’m trying to write a policy net for REINFORCE, but I’m running into a backward()
variable versioning issue. A (relatively) small reproducible snippet is:
import torch.nn.functional as F
import numpy as np
import random
import torch
K = 4
G = K**2 // 4 + 1  # = 5 for K = 4
MAX_EPISODE_LEN = 2
NUM_EPISODES = 1

class PolicyNetA(torch.nn.Module):
    def __init__(self, G: int, KS: int, cnt_iters: int):
        super(PolicyNetA, self).__init__()
        self.cnt_iters = cnt_iters
        self.G = G
        # each conv (plus the view in forward()) maps a (G, G, G) tensor to the same shape.
        self.conv_z = torch.nn.ModuleList([torch.nn.Conv3d(in_channels=1, out_channels=G, kernel_size=(G, 1, 1)) for _ in range(cnt_iters)])

    def forward(self, x: torch.Tensor):
        for i in range(self.cnt_iters):
            x_unsq = x.unsqueeze(0).unsqueeze(0)
            # using reshape() instead of view() didn't help.
            # F.relu has inplace set to False by default.
            x = F.relu(self.conv_z[i](x_unsq).view(self.G, self.G, self.G))
        return x

def rollout_episode(pnet, X: torch.Tensor):
    policy_proba_dists = []
    for t in range(MAX_EPISODE_LEN):
        policy_proba_dist = pnet(X)
        policy_proba_dists.append(policy_proba_dist)
        # X should be detached from the current computation graph for the next forward call.
        X = torch.randn(G, G, G).round()
    return policy_proba_dists

def main():
    with torch.autograd.set_detect_anomaly(True):
        pnet = PolicyNetA(G=G, KS=3, cnt_iters=2)
        optimizer = torch.optim.Adam(params=pnet.parameters())
        for cnt_episode in range(1, NUM_EPISODES + 1):
            X = torch.randn(G, G, G).round()
            policy_proba_dists = rollout_episode(pnet, X)
            for t in range(MAX_EPISODE_LEN):
                gain_log_policy_value = -0.1 * torch.log(policy_proba_dists[t][0, 2, 1] + 1e-10)
                optimizer.zero_grad()
                gain_log_policy_value.backward()
                optimizer.step()

if __name__ == "__main__":
    main()
The error is: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5, 1, 5, 1, 1]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.
I guess the actual error occurs somewhere inside Conv3d: the [5, 1, 5, 1, 1] tensor in the message matches the layer's weight shape (out_channels=5, in_channels=1, kernel_size=(5, 1, 1)) rather than its output, whose final shape is [1, 5, 1, 5, 5].
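For what it's worth, my (possibly mistaken) mental model of this class of error is the toy reduction below, where optimizer.step() updates a parameter in place between two backward() calls whose graphs both saved it. I'm not sure my snippet actually reduces to this, though:

import torch

w = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([w], lr=0.1)
loss_a = (w * w).sum()  # this graph saves w for backward
loss_b = (w * w).sum()  # so does this one
loss_a.backward()
opt.step()              # in-place parameter update bumps w's version counter
loss_b.backward()       # raises the same "modified by an inplace operation" RuntimeError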
Using reshape instead of view after the convolution doesn't help.
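For reference, the attempted variant inside forward() was just:

# same RuntimeError as with view()
x = F.relu(self.conv_z[i](x_unsq).reshape(self.G, self.G, self.G))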
The error goes away if the rollout runs for only one step instead of two, i.e. with MAX_EPISODE_LEN = 1.
I may have used torch.nn.ModuleList incorrectly, but I'm not sure what the correct usage would be. I tried using two separate Conv3d instances not wrapped in a ModuleList, but the error persists.
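Concretely, that attempt looked roughly like this (hypothetical attribute names; the structure is what I actually tried):

# in __init__, replacing the ModuleList:
self.conv_z_a = torch.nn.Conv3d(in_channels=1, out_channels=G, kernel_size=(G, 1, 1))
self.conv_z_b = torch.nn.Conv3d(in_channels=1, out_channels=G, kernel_size=(G, 1, 1))

# in forward(), applied in sequence instead of indexing self.conv_z[i]:
x = F.relu(self.conv_z_a(x.unsqueeze(0).unsqueeze(0)).view(self.G, self.G, self.G))
x = F.relu(self.conv_z_b(x.unsqueeze(0).unsqueeze(0)).view(self.G, self.G, self.G))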
Thank you for your help!