class shared_feature(torch.nn.Module):
    """Hold one learnable weight matrix meant to be shared by other modules.

    The matrix is exposed as the ``weights`` attribute. Because this class is
    a ``torch.nn.Module``, ``weights`` is a registered parameter: it shows up
    in ``parameters()`` / ``state_dict()`` and follows ``.to(device)``.

    Args:
        out: Number of rows of the weight matrix.
        inp: Number of columns of the weight matrix.
    """

    def __init__(self, out: int, inp: int) -> None:
        super().__init__()
        # Allocate and fill in one step, before wrapping in Parameter, so the
        # initialisation never touches autograd and needs no ``.data`` access.
        initial = torch.empty(out, inp).uniform_(-1.0, 1.0)
        self.weights = torch.nn.Parameter(initial, requires_grad=True)
class xtype_feature(nn.Module):
    """Type-specific feature matrix that concatenates input with shared weights.

    Args:
        out: Number of rows of the ``feature`` matrix.
        inp: Number of columns of the ``feature`` matrix.
        shared: Optional object exposing a ``weights`` tensor (for example a
            ``shared_feature`` instance). Passing the *same* object to several
            ``xtype_feature`` modules makes them all read the same parameter.
            If it is an ``nn.Module`` it is registered as a sub-module, so its
            parameters are reachable through ``parameters()``.
        dim: Axis along which ``forward`` concatenates. Defaults to 0, the
            ``torch.cat`` default.
            NOTE(review): the intended axis is not visible in the original
            code - confirm against the shapes used by the caller.
    """

    def __init__(self, out: int, inp: int, shared=None, *, dim: int = 0) -> None:
        super().__init__()
        # The original initialised ``self.weights``, which was never created;
        # the parameter owned by this module is ``self.feature``.
        initial = torch.empty(out, inp).uniform_(-1.0, 1.0)
        self.feature = torch.nn.Parameter(initial, requires_grad=True)
        self.shared = shared
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` concatenated with the shared weights along ``self.dim``.

        ``x`` must match the shared weights in every dimension except
        ``self.dim`` (a ``torch.cat`` requirement).

        Raises:
            RuntimeError: If no ``shared`` object was given at construction.
        """
        if self.shared is None:
            raise RuntimeError(
                "xtype_feature.forward needs shared weights; pass "
                "shared=<object with a .weights tensor> to the constructor"
            )
        # torch.cat takes a *sequence* of tensors followed by the axis.
        return torch.cat((x, self.shared.weights), dim=self.dim)