Hi,
Sorry for the delay.
I am still convinced that I can do this by modifying the raw C++ file. I will walk you through what I have done so that you can fully understand it and, hopefully, point me in the right direction.
To restate my problem: I would like to change the internal structure of a GRU cell to implement what we call a recurrent attention unit (RAU) (here), as can be seen in the image above. It is just a matter of adding two gates (ReLU and softmax).
- I downloaded PyTorch from source by running the following commands:
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
git submodule sync
git submodule update --init --recursive
After this I ran:
export _GLIBCXX_USE_CXX11_ABI=1
- When this had finished, I went to pytorch/aten/src/ATen/native/RNN.cpp and changed the following struct to implement my RAU:
template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda() || input.is_xpu() || input.is_privateuseone()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(5, 1)
        : params.linear_ih(input).unsafe_chunk(5, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(5, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    const auto attention_gate_ReLU =
        chunked_hgates[3].add(chunked_igates[3]).relu_();
    const auto attention_gate_softmax =
        at::softmax(chunked_hgates[4].add(chunked_igates[4]), /*dim=*/-1);
    auto gru_normal = (hidden - new_gate).mul_(input_gate).add_(new_gate);
    auto rau = gru_normal + attention_gate_ReLU.mul(attention_gate_softmax);
#warning "C Preprocessor got here!"
    return rau;
  }
};
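As a reference for the arithmetic in the CPU branch above, a minimal pure-Python sketch of one RAU step for a single sample might look like the following. It assumes the five pre-activation chunks already include the biases, and the helper names (`sigmoid`, `softmax`, `rau_step`) are hypothetical, not part of PyTorch.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(values):
    # Subtract the maximum for numerical stability.
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def rau_step(igates, hgates, hidden):
    """One RAU update for a single sample (hypothetical reference helper).

    igates and hgates each hold 5 chunks (lists of length H) of
    pre-activations, in the same order as the C++ code: reset, input,
    new, ReLU attention, softmax attention. Biases are assumed to be
    already added.
    """
    i_r, i_z, i_n, i_a, i_s = igates
    h_r, h_z, h_n, h_a, h_s = hgates
    reset_gate = [sigmoid(a + b) for a, b in zip(i_r, h_r)]
    input_gate = [sigmoid(a + b) for a, b in zip(i_z, h_z)]
    # The reset gate scales only the hidden contribution of the new gate.
    new_gate = [math.tanh(a + r * b) for a, r, b in zip(i_n, reset_gate, h_n)]
    att_relu = [max(0.0, a + b) for a, b in zip(i_a, h_a)]
    # Softmax runs over the hidden dimension, matching dim=-1 in the C++ code.
    att_soft = softmax([a + b for a, b in zip(i_s, h_s)])
    gru_normal = [(h - n) * z + n for h, n, z in zip(hidden, new_gate, input_gate)]
    return [g + a * s for g, a, s in zip(gru_normal, att_relu, att_soft)]
```

Comparing the output of the modified cell against a reference like this on a small input is one way to check that the two extra gates are wired as intended.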
I have also added the following includes for softmax:
#include <ATen/ops/_log_softmax.h>
#include <ATen/ops/_log_softmax_backward_data_native.h>
#include <ATen/ops/_log_softmax_native.h>
#include <ATen/ops/_masked_softmax_backward_native.h>
#include <ATen/ops/_masked_softmax_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_softmax_backward_data_native.h>
#include <ATen/ops/_softmax_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/log_softmax_native.h>
#include <ATen/ops/softmax.h>
#include <ATen/ops/softmax_native.h>
#include <ATen/ops/special_log_softmax_native.h>
#include <ATen/ops/special_softmax_native.h>
- I created a virtual environment in which to build the PyTorch module:
python3 -m venv ./test2
(test2) $ pip install --verbose pytorch/.
When I do this, everything compiles fine and the warning inside the GRUCell struct fires, which means that the build is compiling the RNN.cpp file I edited and not some other copy.
Within my test2 venv, two entries were created: torch and torch-2.4.0a0+git8aa08b8.dist-info
Now that everything is fine, I checked that the torch module within my test2 venv was actually built from my local PyTorch source with my changes:
(test2) $ python3
>>> import torch
>>> print(torch.__version__)
2.4.0a0+git8aa08b8
This means that it is actually picking up the local version of PyTorch that I downloaded.
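The version string identifies the build, but it does not show which copy of the Python files is imported. A minimal sketch for checking that, using only the standard library, is given here; `package_origin` is a hypothetical helper name, and I am assuming a regular (non-editable) pip install, which would place the package under the venv's site-packages rather than in the source checkout.

```python
import importlib.util


def package_origin(name):
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin


# Prints the path of torch/__init__.py that `import torch` would use,
# or None if torch is not installed in the active environment.
print(package_origin("torch"))
```

If the printed path is inside the venv, edits made to rnn.py in the source checkout would only take effect after reinstalling.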
Now let’s move to the Python side.
The goal is that when I call nn.GRU, it uses the modified GRU cell from my local PyTorch version. Because the number of gates has changed from 3 to 5, I would need to modify this line in rnn.py:
elif mode == 'GRU':
    gate_size = 3 * hidden_size
to this line:
elif mode == 'GRU':
    gate_size = 5 * hidden_size
However, this throws an error.
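For context, gate_size sets the leading dimension of each layer's weights and biases, so going from 3 to 5 gates changes every parameter shape that reaches the C++ side. A minimal sketch of that arithmetic, assuming a single unidirectional first layer without projections (`gru_param_shapes` is a hypothetical helper, not a PyTorch function):

```python
def gru_param_shapes(input_size, hidden_size, num_gates=3):
    """Expected parameter shapes for one layer, given the gate count.

    Hypothetical helper: assumes the first layer of a unidirectional
    GRU with no projection, so weight_ih has input_size columns.
    """
    gate_size = num_gates * hidden_size
    return {
        "weight_ih": (gate_size, input_size),
        "weight_hh": (gate_size, hidden_size),
        "bias_ih": (gate_size,),
        "bias_hh": (gate_size,),
    }


# Stock GRU (3 gates) versus the RAU variant (5 gates).
print(gru_param_shapes(10, 20, num_gates=3))
print(gru_param_shapes(10, 20, num_gates=5))
```

Any code path that still expects the 3-gate shapes would see a mismatch once the parameters are created with 5 gates.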
My question is whether I have missed any steps; I need help understanding what I am doing wrong.
Cheers