Why is tensor.shape / tensor.size() not a tensor? What is the need to also have a torch.Size?
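For reference, in eager mode torch.Size is a subclass of Python's built-in tuple, so indexing it yields plain ints, and getting a tensor out of it has to be done explicitly. A minimal sketch (the input x is just an arbitrary example tensor):

```python
import torch

# Arbitrary 4d example input (N, C, H, W)
x = torch.zeros(2, 3, 8, 8)

# torch.Size subclasses tuple, so each dimension is a plain Python int in eager mode
print(isinstance(x.shape, torch.Size))  # True
print(isinstance(x.shape, tuple))       # True
print(type(x.shape[2]))                 # <class 'int'>
print(x.shape == x.size())              # True

# When a tensor is needed, the shape can be converted explicitly
shape_t = torch.tensor(x.shape)
print(shape_t)        # tensor([2, 3, 8, 8])
print(shape_t.dtype)  # torch.int64
```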
When tracing, tensor.shape[0] is a tensor, but when not tracing, tensor.shape[0] is an int. So I have to check with isinstance, like this:
import torch
import torch.nn as nn


@torch.jit.script
def center_slice_helper(x, h_offset, w_offset, h_end, w_end):
    return x[:, :, h_offset:h_end, w_offset:w_end]


class CenterCrop(nn.Module):
    def __init__(self, crop_size):
        """Crop from the center of a 4d tensor
        Input shape can be dynamic
        :param crop_size: the center crop size
        """
        super(CenterCrop, self).__init__()
        self.crop_size = crop_size

    def extra_repr(self):
        """Extra information
        """
        return 'crop_size={}'.format(
            self.crop_size
        )

    def forward(self, x):
        # Floor division keeps the offsets integral (valid slice bounds)
        h_offset = (x.shape[2] - self.crop_size) // 2
        w_offset = (x.shape[3] - self.crop_size) // 2
        # Outside of tracing, x.shape[i] is a plain int, so wrap the
        # offsets in tensors to match what tracing produces
        if not isinstance(h_offset, torch.Tensor):
            h_offset, w_offset = torch.tensor(h_offset), torch.tensor(w_offset)
        h_end = h_offset + self.crop_size
        w_end = w_offset + self.crop_size
        return center_slice_helper(x, h_offset, w_offset, h_end, w_end)
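The offset arithmetic in forward works the same whether the sizes are ints or tensors. A minimal pure-Python sketch of the center-crop bounds along one axis, using a hypothetical helper center_crop_bounds (not part of PyTorch), shows why floor division matters: (10 - 4) / 2 gives the float 3.0, which is not a valid slice index, while // gives the int 3.

```python
def center_crop_bounds(size, crop_size):
    """Return (start, end) slice bounds for a centered crop along one axis.

    Hypothetical helper for illustration only. Floor division keeps the
    bounds as ints; true division `/` would return a float.
    """
    if crop_size > size:
        raise ValueError("crop_size must not exceed size")
    start = (size - crop_size) // 2
    return start, start + crop_size


h_start, h_end = center_crop_bounds(10, 4)
print(h_start, h_end)  # 3 7

# With an odd leftover, the extra element is dropped from the far side
print(center_crop_bounds(9, 4))  # (2, 6)

rows = list(range(10))
print(rows[h_start:h_end])  # [3, 4, 5, 6]
```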
If tensor.shape were a tensor, we could use it like any other tensor in the network.