I have just been migrating some code to PyTorch 0.4.0.
Some background: my model outputs scores, which I append to a list during the forward pass. I then need to `torch.cat` the list so that I can later call `.sum()` and then `.backward()`.
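For reference, here is a minimal sketch of that pattern on its own. The tensors `weights` and `batch` are hypothetical stand-ins, and each score only mimics `clause.compute(...)` returning a 1-D batch that is then unsqueezed into a column. With no in-place operations involved, `torch.cat` followed by `.sum().backward()` works as expected:

```python
import torch

# Hypothetical stand-ins for the per-clause scores: each clause yields a
# batch of 4 values computed from a parameter that requires gradients.
weights = torch.ones(3, requires_grad=True)
batch = torch.arange(4, dtype=torch.float32)

outputs = []
for i in range(3):
    score = torch.sigmoid(batch * weights[i])  # shape (4,)
    outputs.append(score.unsqueeze(1))         # column vector, shape (4, 1)

# Stack the column vectors on top of each other, as in the forward pass.
clauses_value_tensor = torch.cat(outputs, dim=0)  # shape (12, 1)
clauses_value_tensor.sum().backward()

print(clauses_value_tensor.shape)  # torch.Size([12, 1])
print(weights.grad)                # one gradient entry per clause weight
```

Here `torch.cat` only records how to split the incoming gradient back into the per-clause pieces; it does not modify its inputs.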
Since updating to 0.4.0, calling `.backward()` on the result of `torch.cat` raises:
`RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation`
As a test, I took a single score out of the list and called `.backward()` on it, and that works, so I am fairly certain it is the `torch.cat()` call that triggers this in-place error.
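One thing worth keeping in mind (a hypothesis, not a diagnosis of this particular model): backpropagating from a single score only exercises that score's part of the graph. If a tensor saved for the backward pass of a *different* score is modified in place, the single-score test still passes, while anything that combines all scores, such as `torch.cat`, fails with exactly this error. A minimal sketch with hypothetical tensors `x` and `buf`:

```python
import torch

x = torch.ones(2, requires_grad=True)

# Score 1 does not depend on the tensor that is later modified in place.
s1 = (x * 3).sum()

# Score 2 goes through exp(), whose backward reuses exp's own output.
buf = x.exp()
s2 = buf.sum()
buf.mul_(2)  # in-place change to a tensor that exp() saved for backward

# Backpropagating through score 1 alone succeeds.
s1.backward(retain_graph=True)

# Concatenating both scores pulls score 2's graph into the backward pass.
try:
    torch.cat([s1.unsqueeze(0), s2.unsqueeze(0)], dim=0).sum().backward()
    cat_failed = False
except RuntimeError as err:
    cat_failed = "inplace operation" in str(err)

print(cat_failed)  # True
```

In this sketch, replacing `buf.mul_(2)` with the out-of-place `buf = buf * 2` lets the concatenated backward pass succeed, because the tensor saved by `exp()` is no longer modified.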
```python
def forward(self, clause_dsets={}, ont_clause_dsets={}, aggregator='keep_all',
            inference=False, verbose=False):
    outputs = []
    for clause_name in clause_dsets.keys():
        clause = self.clauses[clause_name]
        if verbose:
            log.debug('===Training Clause: {}==='.format(clause.label))
        # 1) Pass clause's input dictionary as input, which contains the data
        # per clause literal, with the literal label as key.
        # 2) Compute clause grounding by triggering underlying machinery.
        # 3) Output is a batch column vector, via unsqueeze(1). Append to list
        # of outputs
        outputs.append(clause.compute(
            inference, clause_dsets[clause_name]).unsqueeze(1))
    # Ontology clauses
    for clause_name in ont_clause_dsets.keys():
        clause = self.clauses[clause_name]
        if verbose:
            log.debug('===Training Ontology Clause: {}==='.format(clause.label))
        outputs.append(clause.compute(
            inference, ont_clause_dsets[clause.label]).unsqueeze(1))
    if verbose:
        log.debug('===Aggregating Clauses with {}==='.format(aggregator))
    # Concatenate clauses into big column vector. All clauses have same
    # output expectations (high score), so no need to know which clause is
    # which (unless we want to have that insight in investigation phase)
    clauses_value_tensor = torch.cat(outputs, dim=0)
    ...
```
I can call `.backward()` on results computed before the last line; the error only appears after the `cat` operation.
I may have broken something elsewhere during the migration, but since the error only appears after the `cat` op, and `.backward()` works before it (which I took to mean there are no in-place operations earlier in the graph), is something wrong here?
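To narrow down which clause's output carries the problem, one option is to backpropagate through each entry of `outputs` separately, keeping the graph alive with `retain_graph=True` so shared subgraphs can be traversed again by later checks. A sketch, where `find_failing_outputs` is a hypothetical helper (not part of PyTorch) and the list passed in is assumed to be the `outputs` list built in `forward`; the usage lines below fake such a list with hypothetical tensors:

```python
import torch


def find_failing_outputs(outputs):
    """Return the indices of outputs whose backward pass raises RuntimeError.

    Each output is reduced with .sum() and backpropagated on its own.
    retain_graph=True keeps saved tensors so later checks can reuse shared
    parts of the graph. Gradients accumulated into parameters are a side
    effect of this check.
    """
    failing = []
    for idx, out in enumerate(outputs):
        try:
            out.sum().backward(retain_graph=True)
        except RuntimeError as err:
            print("output {} failed: {}".format(idx, err))
            failing.append(idx)
    return failing


# Hypothetical usage: the second "clause" depends on a tensor that is
# modified in place after the score has been computed.
x = torch.ones(2, requires_grad=True)
buf = x.exp()
outputs = [(x * 3).sum().unsqueeze(0), buf.sum().unsqueeze(0)]
buf.mul_(2)

print(find_failing_outputs(outputs))  # [1]
```

Any parameter gradients accumulated during this check should be cleared (for example with `optimizer.zero_grad()`) before a real update.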