I am interested in implementing a more flexible API for the Linear layer, where the constructor takes only the output feature size; the input feature size is inferred from the shape of the first input tensor. Here is a minimal implementation:
```python
from torch.nn import Module, Linear

class FlexibleLinear(Module):
    def __init__(self, out_feats):
        super().__init__()
        self.out_feats = out_feats
        self.initialized = False
        self.linear = None

    def build(self, x):
        # Create the inner Linear once, sized from the last input dimension
        # (Linear always acts on the last dimension of its input).
        if self.initialized:
            return
        in_feats = x.shape[-1]
        self.linear = Linear(in_feats, self.out_feats)
        self.initialized = True

    def forward(self, x):
        self.build(x)
        return self.linear(x)
```
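For reference, here is a quick usage sketch of the class above (the class is repeated so the snippet is self-contained; the shapes are just illustrative):

```python
import torch
from torch.nn import Module, Linear

class FlexibleLinear(Module):
    """Linear layer that infers in_feats from the first input it sees."""
    def __init__(self, out_feats):
        super().__init__()
        self.out_feats = out_feats
        self.initialized = False
        self.linear = None

    def build(self, x):
        if self.initialized:
            return
        # Size the inner Linear from the last input dimension.
        self.linear = Linear(x.shape[-1], self.out_feats)
        self.initialized = True

    def forward(self, x):
        self.build(x)
        return self.linear(x)

# The inner Linear is created on the first forward call.
layer = FlexibleLinear(out_feats=8)
x = torch.randn(4, 16)   # batch of 4, 16 input features
y = layer(x)
print(y.shape)           # torch.Size([4, 8])
```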
I am wondering whether there is a better way to do this, and/or whether this approach can cause problems in a larger network.
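One concrete symptom of the concern above, sketched below: until the first forward call the module has no parameters at all, so anything that iterates `parameters()` before then (e.g. optimizer construction) would see an empty list. The class is repeated so the snippet is self-contained:

```python
import torch
from torch.nn import Module, Linear

class FlexibleLinear(Module):
    def __init__(self, out_feats):
        super().__init__()
        self.out_feats = out_feats
        self.initialized = False
        self.linear = None

    def build(self, x):
        if self.initialized:
            return
        self.linear = Linear(x.shape[-1], self.out_feats)
        self.initialized = True

    def forward(self, x):
        self.build(x)
        return self.linear(x)

layer = FlexibleLinear(out_feats=8)

# Before the first forward call: no parameters registered yet,
# so an optimizer built here would receive an empty parameter list.
n_before = len(list(layer.parameters()))
print(n_before)   # 0

layer(torch.randn(4, 16))   # first forward builds the inner Linear

# After the first forward call: weight and bias are registered.
n_after = len(list(layer.parameters()))
print(n_after)    # 2
```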