Why does version A below raise an error?
A:
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_to_learn = nn.Embedding(10, 5, requires_grad=True)
        self.embedding_to_not_learn = nn.Embedding(10, 5, requires_grad=False)
If I want to keep some embeddings frozen (not trained), it seems I have to do this instead:
B:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_to_learn = nn.Embedding(10, 5)
        self.embedding_to_not_learn = nn.Embedding(10, 5)

model = Net()
model.embedding_to_not_learn.requires_grad_(requires_grad=False)
What is wrong with A?
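
For reference, here is a minimal self-contained sketch that reproduces both versions. It assumes a standard PyTorch install. The names `trainable`, `frozen`, and `a_error` are only illustrative helpers, and calling `requires_grad_` inside `__init__` is one possible variation of B, not necessarily the recommended way.

```python
import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_to_learn = nn.Embedding(10, 5)
        self.embedding_to_not_learn = nn.Embedding(10, 5)
        # Variation of B: freeze the second table inside the constructor
        # instead of after the model is built.
        self.embedding_to_not_learn.requires_grad_(False)


model = Net()

# Check which parameters would still receive gradients.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print("trainable:", trainable)
print("frozen:", frozen)

# Version A: passing requires_grad to the nn.Embedding constructor.
# Catch the exception so the error type can be inspected.
try:
    nn.Embedding(10, 5, requires_grad=True)
    a_error = None
except TypeError as exc:
    a_error = exc
print("A raised:", type(a_error).__name__)
```

With this setup only `embedding_to_learn.weight` is listed as trainable, while the constructor call from A fails before the layer is even created.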