Hi,
I need to implement a class A, and two children class B and C, and a grand-child D, The code is like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo
import torch.distributed as dist
class A(nn.Module):
def __init__(self, ratio=4, *args, **kwargs):
super().__init__()
self.conv_base = nn.Conv2d(3, 3 * ratio, 3, 1, 1)
class B(A):
def __init__(self, b_args, **kwargs):
super().__init__(ratio=4)
self.conv1 = nn.Conv2d(4, 3, 1, 1, 0)
class C(A):
def __init__(self, c_args, **kwargs):
super().__init__(ratio=4)
self.conv2 = nn.Conv2d(4, 3, 1, 1, 0)
class D(B, C):
def __init__(self, b_args, c_args):
super().__init__(b_args=b_args, c_args=c_args)
self.conv3 = nn.Conv2d(4, 3, 1, 1, 0)
b_args = dict(a=1)
c_args = dict(b=2)
model = D(b_args, c_args)
And I observed this error message:
super().__init__(ratio=4)
TypeError: C.__init__() missing 1 required positional argument: 'c_args'
Would you tell me how could I make this work?