Hello.
I have a series of matrix multiplications inside a for loop, and I want to turn them into one “big” matrix multiplication so that all of them are computed together and the GPU is better utilized.
Here is the current implementation:
The model inputs `x` and `y` both have shape `[batch_size, k, config.hidden_size]`.
For each category id in `[0, 1, 2, 3]` we compute:
- non-linear representations of `x` and `y` → `cat_x_reps`, `cat_y_reps`
- the product of `cat_x_reps` and `cat_y_reps` with a bilinear matrix `C_i`, i.e. `cat_x_reps · C_i · cat_y_repsᵀ`
import torch
from torch import nn
from transformers.activations import ACT2FN
class FullyConnectedLayer(nn.Module):
    """Dense projection followed by activation, layer normalization and dropout.

    Args:
        config: Object providing ``layer_norm_eps`` and ``hidden_act``
            (``hidden_act`` must be a key of ``ACT2FN``).
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        dropout_prob: Dropout probability applied to the output.
    """

    def __init__(self, config, input_dim, output_dim, dropout_prob):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout_prob = dropout_prob

        # Only ``nn`` is imported at file level, so every layer is referenced
        # through it; the bare names (Module, Linear, ...) are not in scope.
        self.dense = nn.Linear(self.input_dim, self.output_dim)
        self.layer_norm = nn.LayerNorm(self.output_dim, eps=config.layer_norm_eps)
        self.activation_func = ACT2FN[config.hidden_act]
        self.dropout = nn.Dropout(self.dropout_prob)

    def forward(self, inputs):
        """Apply dense -> activation -> layer norm -> dropout to ``inputs``."""
        hidden = self.dense(inputs)
        hidden = self.activation_func(hidden)
        hidden = self.layer_norm(hidden)
        return self.dropout(hidden)
class MyClass(nn.Module):
    """Per-category bilinear scorer over two sequences of representations.

    For every category (head) the inputs are passed through their own
    ``FullyConnectedLayer``; the x-side result is then transformed by a
    per-category linear map and multiplied with the transposed y-side result.

    Args:
        config: Model configuration; ``hidden_size`` is read here and the
            object is forwarded to ``FullyConnectedLayer``.
        args: Object providing ``ffnn_size``, ``num_heads`` and ``dropout_prob``.
    """

    def __init__(self, config, args):
        # nn.Module.__init__ takes no arguments; passing ``config`` raises.
        super().__init__()
        self.ffnn_size = args.ffnn_size
        self.num_heads = args.num_heads

        self.x_cat_mlp = nn.ModuleList(
            [
                FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, args.dropout_prob)
                for _ in range(self.num_heads)
            ]
        )
        self.y_cat_mlp = nn.ModuleList(
            [
                FullyConnectedLayer(config, config.hidden_size, self.ffnn_size, args.dropout_prob)
                for _ in range(self.num_heads)
            ]
        )
        # One bilinear map per category.
        self.classifiers = nn.ModuleList(
            [nn.Linear(self.ffnn_size, self.ffnn_size) for _ in range(self.num_heads)]
        )

    def forward(self, x_reps, y_reps):
        """Compute the per-category pairwise logits.

        Args:
            x_reps: Tensor of shape ``[batch_size, k, config.hidden_size]``.
            y_reps: Tensor of shape ``[batch_size, k, config.hidden_size]``.

        Returns:
            Tensor of shape ``[batch_size, num_heads, k, k]``.
        """
        cat_logit_list = []
        for x_mlp, y_mlp, classifier in zip(self.x_cat_mlp, self.y_cat_mlp, self.classifiers):
            cat_x_reps = x_mlp(x_reps)  # [batch_size, k, ffnn_size]
            cat_y_reps = y_mlp(y_reps)  # [batch_size, k, ffnn_size]
            projected = classifier(cat_x_reps)  # [batch_size, k, ffnn_size]
            # Batched matmul against the transposed y side -> [batch_size, k, k]
            cat_logits = torch.matmul(projected, cat_y_reps.transpose(1, 2))
            cat_logit_list.append(cat_logits)
        return torch.stack(cat_logit_list, dim=1)
I was able to do this for the attribute `self.x_cat_mlp`, because all of its operations are applied to the same input (`x`, `y`).
I defined:
self.all_cats_size = self.ffnn_size * self.num_cats
# NOTE(review): with a single FullyConnectedLayer the LayerNorm now spans all
# num_cats * ffnn features jointly rather than each category's ffnn features.
self.x_cat_mlp = FullyConnectedLayer(config, config.hidden_size, self.all_cats_size, args.dropout_prob)
x_reps = self.x_cat_mlp(x)  # [batch, k, num_cats * ffnn]
# Split the last axis into (num_cats, ffnn) first, then move num_cats ahead of k.
# Viewing straight to [batch, num_cats, k, ffnn] only reinterprets memory and
# would mix features of different positions into one category slice.
cat_x_reps = x_reps.view(batch_size, k, self.num_cats, self.ffnn_size).permute(0, 2, 1, 3)  # [batch, num_cats, k, ffnn]
I am struggling to do the same for `self.classifiers`: the input is different each time, i.e. `cat_x_reps` is different for each category.