Implement a large scale Linear layer or use parameter server instead?

I want to implement a model that has a very large Linear layer because the number of features is huge, say 2^32; the actual input will be a sparse tensor.

Typically, if the number of features is small, I can create an ordinary Linear layer, for example

self.lin = Linear(in_features, out_features)

But since in_features = 2^32, the above Linear layer won’t work: the weight matrix alone would have 2^32 × out_features parameters, far more than a single machine can hold.
So I’m thinking about ideas like:

  1. Split the huge Linear into multiple small ones, e.g. each with 2^20 features. I looked at torch.distributed.rpc, but it doesn’t seem able to do this.
  2. Or use a parameter server, but I have no idea how to turn the Linear layer into one.

I couldn’t figure out how to do either of these, so please give me some advice.
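For what it’s worth, idea 1 can at least be prototyped in a single process by splitting the weight over the input-feature dimension and summing the per-shard partial outputs; in a multi-machine setup each shard would live in its own process and the partial outputs would be all-reduced. A minimal single-process sketch (the class name and sizes are made up for illustration):

```python
import torch
import torch.nn as nn

class ShardedLinear(nn.Module):
    """Split a Linear over the input-feature dimension: each shard owns a
    slice of the weight columns, and the full output is the sum of the
    per-shard partial products. Single-process sketch only."""
    def __init__(self, in_features, out_features, num_shards):
        super().__init__()
        assert in_features % num_shards == 0
        self.shard_size = in_features // num_shards
        # only shard 0 carries the bias, so it is not added num_shards times
        self.shards = nn.ModuleList(
            nn.Linear(self.shard_size, out_features, bias=(i == 0))
            for i in range(num_shards)
        )

    def forward(self, x):
        # x: (batch, in_features); slice it to match each shard's columns
        chunks = x.split(self.shard_size, dim=1)
        return sum(shard(c) for shard, c in zip(self.shards, chunks))

# sanity check: the sharded version matches a plain Linear with the same weights
torch.manual_seed(0)
full = nn.Linear(8, 3)
sharded = ShardedLinear(8, 3, num_shards=4)
with torch.no_grad():
    for i, shard in enumerate(sharded.shards):
        lo, hi = i * sharded.shard_size, (i + 1) * sharded.shard_size
        shard.weight.copy_(full.weight[:, lo:hi])
    sharded.shards[0].bias.copy_(full.bias)

x = torch.randn(5, 8)
print(torch.allclose(full(x), sharded(x), atol=1e-6))  # True
```

With a sparse input you would additionally skip shards whose slice of the input is all zeros, which is where the real savings come from.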

Regarding the parameter server idea, I guess a parameter server and a linear layer are quite different concepts: a PS is a training paradigm, while a linear layer is a component of a model.

It looks like you have a use case for sharding a linear layer across multiple processes and potentially multiple nodes. We are planning to build out sharding primitives within PyTorch; please see this RFC: [RFC] Model Sharding for distributed training · Issue #55207 · pytorch/pytorch · GitHub, which would suit this use case once it is built out.

Curious, if the input is a sparse feature, would an nn.EmbeddingBag work better than a linear layer?
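To illustrate why EmbeddingBag fits sparse features: for a multi-hot input, a bias-free Linear is mathematically equivalent to an nn.EmbeddingBag in "sum" mode, but the EmbeddingBag only looks up the rows for the active feature ids instead of multiplying a mostly-zero dense vector. A small sketch (sizes are made up):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_features, out_dim = 10, 4

lin = nn.Linear(num_features, out_dim, bias=False)
emb = nn.EmbeddingBag(num_features, out_dim, mode="sum")
with torch.no_grad():
    # EmbeddingBag stores one row per feature, i.e. the transpose of lin.weight
    emb.weight.copy_(lin.weight.t())

# two "multi-hot" samples given as active feature ids only
idx = torch.tensor([1, 3, 7, 2, 5])
offsets = torch.tensor([0, 3])  # sample 0 = {1, 3, 7}, sample 1 = {2, 5}

# dense multi-hot equivalent for the Linear path
dense = torch.zeros(2, num_features)
dense[0, [1, 3, 7]] = 1.0
dense[1, [2, 5]] = 1.0

print(torch.allclose(lin(dense), emb(idx, offsets), atol=1e-6))  # True
```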


Thanks for your answer, Rohan!
Yes, EmbeddingBag also looks great, but meanwhile I want to see how to shard a Linear layer across multiple processes/machines.
Will definitely check out the RFC.

BTW, even if it’s nn.EmbeddingBag (actually I don’t need the aggregation operation, so I think nn.Embedding would be better), we still have to think about sharding,
because the embedding table would still be huge (2^32 rows) and might not fit on a single machine.
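One common way to shard a huge embedding table is to route each feature id to a shard, e.g. by modulo, with each shard storing only its own rows. A hypothetical single-process sketch (in a real deployment each shard would live on a different machine, for example behind torch.distributed.rpc, and the routing logic would run on the trainer):

```python
import torch
import torch.nn as nn

class ShardedEmbedding(nn.Module):
    """Single-process sketch of sharding a large nn.Embedding:
    feature id -> shard (id % num_shards), local row (id // num_shards)."""
    def __init__(self, num_embeddings, dim, num_shards):
        super().__init__()
        self.num_shards = num_shards
        rows_per_shard = (num_embeddings + num_shards - 1) // num_shards
        self.shards = nn.ModuleList(
            nn.Embedding(rows_per_shard, dim) for _ in range(num_shards)
        )

    def forward(self, ids):
        # gather each id's vector from the shard that owns it
        out = torch.empty(ids.numel(), self.shards[0].embedding_dim)
        for s in range(self.num_shards):
            mask = (ids % self.num_shards) == s
            if mask.any():
                out[mask] = self.shards[s](ids[mask] // self.num_shards)
        return out

emb = ShardedEmbedding(100, 8, num_shards=4)
vecs = emb(torch.tensor([0, 1, 42, 99]))
print(vecs.shape)  # torch.Size([4, 8])
```

Each shard only needs num_embeddings / num_shards rows of memory, so with enough shards even a 2^32-row table becomes feasible, at the cost of a network hop per lookup.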