Class autograd.Function in a module causes autoreload to fail in JupyterLab

I define a foo.py file and then create a JupyterLab notebook that imports the class CNN from this file.
In foo.py, I first define class CNN and then class fun:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Spiking threshold and surrogate-gradient window (assumed values)
thresh = 0.5
lens = 0.5

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1, 1)
        print('pytorch')
        print('ok')


class fun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Spike when the input exceeds the threshold
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Surrogate gradient: let gradients pass only near the threshold
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()
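For context, a Function subclass like this is not instantiated; it is invoked through its apply classmethod. A minimal sketch of how it might be wired together (the input shape and the call site are my assumptions, not from the file above):

x = torch.randn(4, 1, 28, 28)        # assumed batch of single-channel images
model = CNN()
spikes = fun.apply(model.conv1(x))   # custom Functions are called via .apply
spikes.sum().backward()              # triggers fun.backward (the surrogate gradient)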

In JupyterLab, I enable the autoreload extension:

%load_ext autoreload
%autoreload 2

The autoreload magic works in this case. However, if I swap the order of the class definitions in foo.py,
so that class fun is defined first and class CNN second, autoreload stops working in my JupyterLab notebook.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Spiking threshold and surrogate-gradient window (assumed values)
thresh = 0.5
lens = 0.5

class fun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1, 1)
        print('pytorch')
        print('ok')

In this case, if I modify __init__ in class CNN, autoreload does not pick up the change.
Can anyone explain why?

Hi,

Quite weird indeed.
Does it only happen when you define a Function subclass? What happens if you define just any class:

class Dummy():
    pass

I found that if I define class fun() instead of class fun(torch.autograd.Function), autoreload succeeds.

I also tried the following; with this, autoreload works for both orderings.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Spiking threshold and surrogate-gradient window (assumed values)
thresh = 0.5
lens = 0.5

class fun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()

class Dummy():
    print("torch")


Hi,

I don’t think PyTorch supports reloading its module, so this is most likely going to fail in unexpected ways if you try to reload(torch), no?

Thanks a lot for your help. I don’t understand, though: where would reload(torch) happen, and why would it be needed? I only modify the module that I wrote; the torch module itself is untouched.

I am not sure how autoreload works, but could it be reloading torch when it reloads your module?
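One way to check is to reload only your own module by hand, leaving torch alone. A sketch, assuming the file is named foo.py:

import importlib
import foo

importlib.reload(foo)   # re-executes foo.py; torch stays cached in sys.modules
net = foo.CNN()         # picks up the new class definition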

Thank you, maybe I should report this to IPython on GitHub.
The official docs say that %autoreload 2 will "Reload all modules (except those excluded by %aimport) every time before executing the Python code typed."
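Given that wording, one workaround worth trying (I have not verified that it fixes the Function case) is to exclude torch from autoreload with %aimport, so only your own modules get reloaded:

%load_ext autoreload
%autoreload 2
%aimport -torch   # per the IPython docs, marks torch to never be autoreloaded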

I know that using importlib.reload(torch) will not behave properly (it usually ends up segfaulting); we have an open issue tracking this, but it will be quite hard to fix.
