'Proxy' object does not support item assignment

class TemporalShift(nn.Module):
    """Temporal Shift Module (TSM).

    Mixes information between neighbouring frames at zero extra FLOPs by
    shifting a fraction of the channels one step along the time axis:
    the first ``c // fold_div`` channels are shifted toward earlier time,
    the next ``c // fold_div`` toward later time, and the rest are copied
    unchanged.

    Args:
        n_segment: number of frames (time steps) packed into the batch
            dimension of the input.
        n_div: channel split divisor; 2/n_div of the channels get shifted.
    """

    def __init__(self, n_segment=3, n_div=8):
        super(TemporalShift, self).__init__()
        self.n_segment = n_segment
        self.fold_div = n_div

    def forward(self, x):
        # x: (n_batch * n_segment, c, h, w) — frames are interleaved
        # along the batch dimension.
        return self.shift(x, self.n_segment, fold_div=self.fold_div)

    @staticmethod
    def shift(x, n_segment, fold_div=3):
        """Shift channels along the temporal dimension.

        Args:
            x: tensor of shape (n_batch * n_segment, c, h, w); nt must be
                divisible by n_segment.
            n_segment: frames per clip.
            fold_div: channel split divisor (fold = c // fold_div).

        Returns:
            Tensor of the same shape as ``x`` with shifted channels.
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div

        # Allocate the output directly as zeros: the original
        # x.clone().zero_() copied x only to immediately erase it.
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left (toward earlier time)
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # remaining channels unchanged

        return out.view(nt, c, h, w)

When using FX to symbolically trace TSM, I get this error:

out[:, :-1, :fold] = x[:, 1:, :fold]
TypeError: 'Proxy' object does not support item assignment

Hi @keyky, this is a limitation of symbolic tracing with FX. Here is a workaround using torch.fx.wrap:

@torch.fx.wrap
def shift_left(out, x, fold):
    """In place: move the first `fold` channels one step earlier in time."""
    out[:, :-1, :fold].copy_(x[:, 1:, :fold])
                                                                         
@torch.fx.wrap
def shift_right(out, x, fold):
    """In place: move channels [fold, 2*fold) one step later in time."""
    lo, hi = fold, 2 * fold
    out[:, 1:, lo:hi].copy_(x[:, :-1, lo:hi])
                                                                         
@torch.fx.wrap
def not_shift(out, x, fold):
    """In place: copy channels from 2*fold onward without any temporal shift."""
    start = 2 * fold
    out[:, :, start:].copy_(x[:, :, start:])
                                                                         
class TemporalShift(nn.Module):
    """FX-traceable Temporal Shift Module.

    The in-place slice assignments are delegated to the module-level
    helpers `shift_left`, `shift_right` and `not_shift`, which are
    decorated with `torch.fx.wrap` so symbolic tracing treats them as
    opaque calls instead of tracing into the item assignments.
    """

    def __init__(self, n_segment=3, n_div=8):
        super(TemporalShift, self).__init__()
        self.n_segment = n_segment
        self.fold_div = n_div

    def forward(self, x):
        return self.shift(x, self.n_segment, fold_div=self.fold_div)

    @staticmethod
    def shift(x, n_segment, fold_div=3):
        """Shift a fraction (2 / fold_div) of channels along time."""
        nt, c, h, w = x.size()
        frames = x.view(nt // n_segment, n_segment, c, h, w)

        fold = c // fold_div

        # Zero-filled output buffer; the helpers fill it in place.
        out = frames.clone().zero_()
        shift_left(out, frames, fold)
        shift_right(out, frames, fold)
        not_shift(out, frames, fold)

        return out.view(nt, c, h, w)
                                                                         
# Symbolically trace the module with FX; the torch.fx.wrap'd helpers are
# recorded as call nodes rather than traced through, so the Proxy
# item-assignment error no longer occurs.
m = TemporalShift()
ms = torch.fx.symbolic_trace(m)