RuntimeError for chunk inplace operation, new with torch 1.7

Since updating to torch 1.7, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [24, 129, 1536]]

The error is raised by this line:
return self._in_proj(query).chunk(3, dim=-1)
and can be worked around (temporarily) with:
return self._in_proj(query).unsafe_chunk(3, dim=-1)
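For reference, here is a small sketch of what the two calls return. The tensor `x`, the stand-in `proj` and the shape [2, 5, 12] are assumptions. The shape was chosen so that the last dimension splits evenly into 3, as 1536 does. Both calls give the same three pieces without copying any data. As far as I can tell from the docs, `unsafe_chunk` differs only in skipping autograd's view/in-place bookkeeping. That would explain why it makes the error go away without changing what the code computes.

```python
import torch

# Hypothetical stand-in for the [24, 129, 1536] activation in the error:
# a small tensor whose last dimension (12) splits evenly into 3 pieces.
x = torch.randn(2, 5, 12, requires_grad=True)
proj = x * 2.0  # non-leaf tensor, playing the role of self._in_proj(query)

safe = list(proj.chunk(3, dim=-1))
unsafe = list(proj.unsafe_chunk(3, dim=-1))

for i, (a, b) in enumerate(zip(safe, unsafe)):
    # Same shape and values from both calls.
    assert a.shape == b.shape == (2, 5, 4)
    assert torch.equal(a, b)
    # No copy: piece i starts 4 * i elements into proj's storage.
    assert a.storage_offset() == b.storage_offset() == 4 * i
```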

My understanding is that unsafe_chunk will be removed in the future. Is there a “correct” fix for this? This isn’t code I wrote, and I’m unsure of the proper way to get the same behavior as the original.

Is this an appropriate replacement or is there something more elegant?

proj = self._in_proj(query)
sz   = proj.size()[2] // 3
return proj[:,:,:sz], proj[:,:,sz:2*sz], proj[:,:,2*sz:]
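The sketch below checks the slicing version against `chunk` on a hypothetical small tensor. The slices match `chunk` exactly, but they are still views into `proj`'s memory. The code that uses q, k and v isn't shown here. If the error comes from an in-place op on one of them, slicing may hit the same problem. One option is to clone each piece, as the hypothetical `split_qkv_cloned` helper does. That gives independent tensors at the cost of one extra copy. Another option assumes the downstream code does something like `q *= scaling`: make that op out-of-place instead (`q = q * scaling`).

```python
import torch

def split_qkv_sliced(proj: torch.Tensor):
    # The slicing from the question, written for any number of leading dims.
    # Equivalent to proj.chunk(3, dim=-1) when the last dim is divisible by 3.
    sz = proj.size(-1) // 3
    return proj[..., :sz], proj[..., sz:2 * sz], proj[..., 2 * sz:]

def split_qkv_cloned(proj: torch.Tensor):
    # Hypothetical helper (not from the original code): clone each chunk so
    # later in-place ops on q/k/v do not write into proj's memory.
    return tuple(t.clone() for t in proj.chunk(3, dim=-1))

# Hypothetical small stand-in for the [24, 129, 1536] tensor in the error.
x = torch.randn(2, 5, 12, requires_grad=True)
proj = x * 2.0  # non-leaf, like the output of self._in_proj(query)

q_c, k_c, v_c = proj.chunk(3, dim=-1)
q_s, k_s, v_s = split_qkv_sliced(proj)
q, k, v = split_qkv_cloned(proj)

# Slicing and chunk give identical values.
assert all(torch.equal(a, b) for a, b in zip((q_c, k_c, v_c), (q_s, k_s, v_s)))

# Slices are still views into proj's memory; clones are not.
assert q_s.data_ptr() == proj.data_ptr()
assert q.data_ptr() != proj.data_ptr()

# Modifying a clone in place is allowed, and gradients still reach x.
q *= 0.5
(q.sum() + k.sum() + v.sum()).backward()
```

The gradient works out to 0.5 * 2.0 = 1.0 for the q columns of `x` and 2.0 for the k and v columns. This shows that the clone keeps the autograd graph intact.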