C++ LibTorch MHA issues

Hello, I'm trying to implement a transformer that classifies SHA-256 hashes as benign or malicious using C++ LibTorch 2.6.0+cu126, and I've run into issues with the attention layer in both of my attempts. Each hash is a 64-character hex string, which I encode as a [batch, 64] Long tensor of digit indices (0-15), and the model should output a binary clean/malicious label.

First Implementation:

#include <torch/torch.h>
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <cmath>
#include <cassert>

struct MultiHeadAttentionImpl : torch::nn::Module {
    MultiHeadAttentionImpl(int64_t embed_size, int64_t num_heads)
        : embed_size(embed_size), num_heads(num_heads), head_dim(embed_size / num_heads) {
        assert(embed_size % num_heads == 0);

        query = register_module("query", torch::nn::Linear(embed_size, embed_size));
        key = register_module("key", torch::nn::Linear(embed_size, embed_size));
        value = register_module("value", torch::nn::Linear(embed_size, embed_size));
        fc_out = register_module("fc_out", torch::nn::Linear(embed_size, embed_size));

        to(torch::kCUDA);
    }

    torch::Tensor forward(torch::Tensor query, torch::Tensor key, torch::Tensor value) {
        int64_t N = query.size(0);
        int64_t query_len = query.size(1);
        int64_t key_len = key.size(1);
        int64_t value_len = value.size(1);

        std::cout << "Query shape: " << query.sizes() << std::endl;
        std::cout << "Key shape: " << key.sizes() << std::endl;
        std::cout << "Value shape: " << value.sizes() << std::endl;

        torch::Tensor queries, keys, values;

        std::cout << "Query device: " << query.device() << "\n";
        std::cout << "Key device: " << key.device() << "\n";
        std::cout << "Value device: " << value.device() << "\n";

        // Ensure the tensors are of type Float
        query = query.to(torch::kFloat32);
        key = key.to(torch::kFloat32);
        value = value.to(torch::kFloat32);

        try
        {
            queries = this->query(query).view({ N, query_len, num_heads, head_dim }).transpose(1, 2).contiguous();
            keys = this->key(key).view({ N, key_len, num_heads, head_dim }).transpose(1, 2).contiguous();
            values = this->value(value).view({ N, value_len, num_heads, head_dim }).transpose(1, 2).contiguous();
        }
        catch (const c10::Error& e)
        {
            std::cerr << "Error: " << e.msg() << std::endl;
        }

        std::cout << "Shape of queries after linear transformation: " << queries.sizes() << std::endl;
        std::cout << "Shape of keys after linear transformation: " << keys.sizes() << std::endl;
        std::cout << "Shape of values after linear transformation: " << values.sizes() << std::endl;

        auto energy = torch::matmul(queries, keys.transpose(-2, -1)) / std::sqrt(head_dim);

        auto attention = torch::softmax(energy, -1);

        auto out = torch::matmul(attention, values).transpose(1, 2).contiguous();
        out = out.view({ N, query_len, embed_size });
        out = fc_out(out);

        return out;
    }

    int64_t embed_size;
    int64_t num_heads;
    int64_t head_dim;
    torch::nn::Linear query{ nullptr }, key{ nullptr }, value{ nullptr }, fc_out{ nullptr };
};
TORCH_MODULE(MultiHeadAttention);

struct TransformerBlockImpl : torch::nn::Module {
    TransformerBlockImpl(int64_t embed_size, int64_t num_heads, int64_t forward_expansion)
        : attention(embed_size, num_heads),
        norm1(torch::nn::LayerNormOptions({ embed_size })),
        norm2(torch::nn::LayerNormOptions({ embed_size })) {
        feed_forward = register_module("feed_forward", torch::nn::Sequential(
            torch::nn::Linear(embed_size, forward_expansion * embed_size),
            torch::nn::ReLU(),
            torch::nn::Linear(forward_expansion * embed_size, embed_size)
        ));
    }

    torch::Tensor forward(torch::Tensor x) {
        std::cout << "Input shape: " << x.sizes() << std::endl;

        auto attention_out = attention(x, x, x);
        std::cout << "Attention output shape: " << attention_out.sizes() << std::endl;

        x = norm1(x + attention_out);
        std::cout << "Shape after norm1: " << x.sizes() << std::endl;

        auto forward_out = feed_forward->forward(x);
        std::cout << "Feed forward output shape: " << forward_out.sizes() << std::endl;

        x = norm2(x + forward_out);
        std::cout << "Shape after norm2: " << x.sizes() << std::endl;

        return x;
    }

    MultiHeadAttention attention{ nullptr };
    torch::nn::LayerNorm norm1{ nullptr }, norm2{ nullptr };
    torch::nn::Sequential feed_forward{ nullptr };
};
TORCH_MODULE(TransformerBlock);

struct TransformerImpl : torch::nn::Module {
    TransformerImpl(int64_t embed_size, int64_t num_heads, int64_t num_layers, int64_t forward_expansion)
        : embed_size(embed_size),
        layers(torch::nn::Sequential()),
        fc_out(torch::nn::Linear(embed_size, 2)) { // Binary classification
        for (int64_t i = 0; i < num_layers; ++i) {
            layers->push_back(TransformerBlock(embed_size, num_heads, forward_expansion));
        }
        register_module("layers", layers);
        register_module("fc_out", fc_out);
    }

    torch::Tensor forward(torch::Tensor x) {
        std::cout << "Input shape: " << x.sizes() << std::endl;

        for (auto& layer : *layers) {
            x = layer.get<TransformerBlock>()->forward(x);
            std::cout << "Shape after layer: " << x.sizes() << std::endl;
        }

        x = fc_out(x.mean(1));
        std::cout << "Shape after fc_out: " << x.sizes() << std::endl;

        return x;
    }

    int64_t embed_size;
    torch::nn::Sequential layers;
    torch::nn::Linear fc_out;
};
TORCH_MODULE(Transformer);

std::vector<std::string> load_data(const std::string& file_path) {
    std::ifstream file(file_path);
    std::vector<std::string> data;
    std::string line;
    while (std::getline(file, line)) {
        data.emplace_back(line);
    }
    return data;
}

torch::Tensor sha256_to_tensor(const std::string& sha256) {
    std::unordered_map<char, int64_t> char_to_idx = {
        {'0', 0}, {'1', 1}, {'2', 2}, {'3', 3},
        {'4', 4}, {'5', 5}, {'6', 6}, {'7', 7},
        {'8', 8}, {'9', 9}, {'A', 10}, {'B', 11},
        {'C', 12}, {'D', 13}, {'E', 14}, {'F', 15},
        {'a', 10}, {'b', 11}, {'c', 12}, {'d', 13},
        {'e', 14}, {'f', 15}
    };

    std::string trimmed_sha256 = sha256;
    if (trimmed_sha256.size() == 65 && trimmed_sha256[64] == ' ') {
        trimmed_sha256 = trimmed_sha256.substr(0, 64);
    }

    // Ensure the SHA-256 hash string is exactly 64 characters long
    if (trimmed_sha256.size() != 64) {
        throw std::runtime_error("Invalid SHA-256 hash length: " + std::to_string(trimmed_sha256.size()));
    }

    torch::Tensor tensor = torch::zeros({ 64 }, torch::kLong);
    for (size_t i = 0; i < trimmed_sha256.size(); ++i) {
        if (char_to_idx.find(trimmed_sha256[i]) != char_to_idx.end()) {
            tensor[i] = char_to_idx[trimmed_sha256[i]];
        }
        else {
            std::cerr << "Invalid character in SHA-256 hash: " << trimmed_sha256[i] << std::endl;
            throw std::runtime_error("Invalid character in SHA-256 hash: " + std::string(1, trimmed_sha256[i]));
        }
    }
    return tensor;
}

torch::Tensor preprocess_data(const std::vector<std::string>& data) {
    std::vector<torch::Tensor> tensor_data;
    for (const auto& hash : data) {
        tensor_data.emplace_back(sha256_to_tensor(hash));
    }
    return torch::stack(tensor_data);
}

class CustomDataset : public torch::data::Dataset<CustomDataset> {
public:
    CustomDataset(torch::Tensor data, torch::Tensor labels)
        : data_(data), labels_(labels) {
    }

    torch::data::Example<> get(size_t index) override {
        return { data_[index], labels_[index] };
    }

    torch::optional<size_t> size() const override {
        return data_.size(0);
    }

private:
    torch::Tensor data_, labels_;
};

void train_model(Transformer& model, torch::Device device, const std::vector<std::string>& clean_data, const std::vector<std::string>& malicious_data) {
    auto clean_tensor = preprocess_data(clean_data).to(device);
    auto malicious_tensor = preprocess_data(malicious_data).to(device);

    auto clean_labels = torch::zeros({ clean_tensor.size(0) }, torch::kLong).to(device);
    auto malicious_labels = torch::ones({ malicious_tensor.size(0) }, torch::kLong).to(device);

    auto data = torch::cat({ clean_tensor, malicious_tensor }, 0);
    auto labels = torch::cat({ clean_labels, malicious_labels }, 0);

    auto dataset = CustomDataset(data, labels).map(torch::data::transforms::Stack<>());
    auto data_loader = torch::data::make_data_loader(dataset, /*batch_size=*/32);

    torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-3));

    float v_loss = 0.0f;

    model->train();


    for (size_t epoch = 0; epoch < 10; ++epoch) {
        for (auto& batch : *data_loader) {
            auto inputs = batch.data.to(device);
            auto targets = batch.target.to(device);

            optimizer.zero_grad();
            auto outputs = model->forward(inputs);
            auto loss = torch::nn::functional::cross_entropy(outputs, targets);
            v_loss = loss.item<float>();
            loss.backward();
            optimizer.step();
        }
        std::cout << "Epoch [" << epoch + 1 << "/10], Loss: " << v_loss << std::endl;

        // Save the model every 5 epochs
        if ((epoch + 1) % 5 == 0) {
            std::stringstream stream;
            torch::save(model, stream);
            std::string model_path = "KloDA-t-mha-epoch-" + std::to_string(epoch + 1) + ".klo";
            std::ofstream file(model_path, std::ios::binary);
            file << stream.rdbuf();
            file.close();
        }
    }
}

int64_t calculate_total_parameters(int64_t embed_size, int64_t num_heads, int64_t num_layers, int64_t forward_expansion) {
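    // Rough hand count: only the weight matrices and the LayerNorm affine
    // parameters are included; the Linear layers' bias terms are ignored.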
    // MultiHeadAttention parameters
    int64_t mha_params = 4 * embed_size * embed_size;

    // LayerNorm parameters
    int64_t layernorm_params = 2 * embed_size;

    // Feed Forward Network parameters
    int64_t ffn_params = embed_size * forward_expansion * embed_size + forward_expansion * embed_size * embed_size;

    // TransformerBlock parameters
    int64_t transformer_block_params = mha_params + 2 * layernorm_params + ffn_params;

    // Total parameters for all TransformerBlocks
    int64_t total_transformer_blocks_params = num_layers * transformer_block_params;

    // Final Linear layer parameters
    int64_t final_fc_params = embed_size * 2;

    // Total parameters
    int64_t total_params = total_transformer_blocks_params + final_fc_params;

    return total_params;
}

int main() {
    int64_t embed_size = 256;
    int64_t num_heads = 8;
    int64_t num_layers = 6;
    int64_t forward_expansion = 4;

    torch::Device device(torch::kCUDA);

    auto clean_data = load_data("Clean_SHA256.txt");
    auto malicious_data = load_data("KloDA-sha256-db.klo");

    int64_t total_params = calculate_total_parameters(embed_size, num_heads, num_layers, forward_expansion);
    std::cout << "Total number of parameters: " << total_params << std::endl;

    if (total_params > 500000000) {
        std::cerr << "The number of parameters exceeds 0.5 billion." << std::endl;
        return 1;
    }
    else {
        std::cout << "The number of parameters is within the limit of 0.5 billion." << std::endl;
    }

    Transformer model(embed_size, num_heads, num_layers, forward_expansion);
    model->to(device);

    train_model(model, device, clean_data, malicious_data);

    return 0;
}

Output:

Total number of parameters: 4725248
The number of parameters is within the limit of 0.5 billion.
Input shape: [32, 64]
Input shape: [32, 64]
Query shape: [32, 64]
Key shape: [32, 64]
Value shape: [32, 64]
Query device: cuda:0
Key device: cuda:0
Value device: cuda:0
Queries shape before view and transpose: Error: mat1 and mat2 shapes cannot be multiplied (32x64 and 256x256)
Queries shape after view and transpose: [0]
Keys shape after view and transpose: [0]
Values shape after view and transpose: [0]
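
From the "mat1 and mat2 shapes cannot be multiplied (32x64 and 256x256)" message, my guess is that the Linear projections expect the last dimension of their input to be embed_size (256), but I'm feeding the raw [batch, 64] tensor of Long digit indices straight into the block, so the projection fails before the view/transpose ever runs. I suspect I'm missing an embedding step in front of the transformer layers. Below is a minimal standalone sketch of what I mean (just my guess; the Embedding layer and the names here are mine and are not in the code above):

#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t embed_size = 256;

    // 16 possible tokens (hex digits 0-F), each mapped to an embed_size-dim float vector.
    torch::nn::Embedding embedding(torch::nn::EmbeddingOptions(16, embed_size));

    // Dummy batch shaped like my real input: [32, 64] Long indices in [0, 15].
    auto x = torch::randint(0, 16, {32, 64}, torch::kLong);

    auto embedded = embedding(x);  // [32, 64] Long -> [32, 64, 256] Float
    std::cout << embedded.sizes() << std::endl;

    return 0;
}

If that's the right idea, I assume I'd register the Embedding inside TransformerImpl and apply it at the top of forward(), before the layer loop. Is that the correct direction, or is something else wrong in the first implementation?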

Second Implementation:

#include <torch/torch.h>
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <cassert>

struct MultiHeadAttentionImpl : torch::nn::Module {
    MultiHeadAttentionImpl(int64_t embed_size, int64_t num_heads)
        : embed_size(embed_size), num_heads(num_heads), head_dim(embed_size / num_heads) {
        assert(embed_size % num_heads == 0);

        attention = register_module("attention", torch::nn::MultiheadAttention(
            torch::nn::MultiheadAttentionOptions(embed_size, num_heads).dropout(0.1)));

        fc_out = register_module("fc_out", torch::nn::Linear(embed_size, embed_size));

        to(torch::kCUDA);
    }

    torch::Tensor forward(torch::Tensor query, torch::Tensor key, torch::Tensor value) {
        int64_t N = query.size(0);
        int64_t query_len = query.size(1);

        std::cout << "Query shape: " << query.sizes() << std::endl;
        std::cout << "Key shape: " << key.sizes() << std::endl;
        std::cout << "Value shape: " << value.sizes() << std::endl;

        torch::Tensor attn_output, attn_output_weights;

        try
        {
            std::tie(attn_output, attn_output_weights) = attention(query, key, value);
        }
        catch (const torch::Error& e)
        {
            std::cout << e.msg() << "\n";
        }

        std::cout << "Attention output shape: " << attn_output.sizes() << std::endl;

        auto out = fc_out(attn_output);

        return out;
    }

    int64_t embed_size;
    int64_t num_heads;
    int64_t head_dim;
    torch::nn::MultiheadAttention attention{ nullptr };
    torch::nn::Linear fc_out{ nullptr };
};
TORCH_MODULE(MultiHeadAttention);

struct TransformerBlockImpl : torch::nn::Module {
    TransformerBlockImpl(int64_t embed_size, int64_t num_heads, int64_t forward_expansion)
        : attention(embed_size, num_heads),
        norm1(torch::nn::LayerNormOptions({ embed_size })),
        norm2(torch::nn::LayerNormOptions({ embed_size })) {
        feed_forward = register_module("feed_forward", torch::nn::Sequential(
            torch::nn::Linear(embed_size, forward_expansion * embed_size),
            torch::nn::ReLU(),
            torch::nn::Linear(forward_expansion * embed_size, embed_size)
        ));
    }

    torch::Tensor forward(torch::Tensor x) {
        std::cout << "Input shape: " << x.sizes() << std::endl;

        auto attention_out = attention(x, x, x);
        std::cout << "Attention output shape: " << attention_out.sizes() << std::endl;

        x = norm1(x + attention_out);
        std::cout << "Shape after norm1: " << x.sizes() << std::endl;

        auto forward_out = feed_forward->forward(x);
        std::cout << "Feed forward output shape: " << forward_out.sizes() << std::endl;

        x = norm2(x + forward_out);
        std::cout << "Shape after norm2: " << x.sizes() << std::endl;

        return x;
    }

    MultiHeadAttention attention{ nullptr };
    torch::nn::LayerNorm norm1{ nullptr }, norm2{ nullptr };
    torch::nn::Sequential feed_forward{ nullptr };
};
TORCH_MODULE(TransformerBlock);

struct TransformerImpl : torch::nn::Module {
    TransformerImpl(int64_t embed_size, int64_t num_heads, int64_t num_layers, int64_t forward_expansion)
        : embed_size(embed_size),
        layers(torch::nn::Sequential()),
        fc_out(torch::nn::Linear(embed_size, 2)) { // Binary classification
        for (int64_t i = 0; i < num_layers; ++i) {
            layers->push_back(TransformerBlock(embed_size, num_heads, forward_expansion));
        }
        register_module("layers", layers);
        register_module("fc_out", fc_out);
    }

    torch::Tensor forward(torch::Tensor x) {
        std::cout << "Input shape: " << x.sizes() << std::endl;

        for (auto& layer : *layers) {
            x = layer.get<TransformerBlock>()->forward(x);
            std::cout << "Shape after layer: " << x.sizes() << std::endl;
        }

        x = fc_out(x.mean(1));
        std::cout << "Shape after fc_out: " << x.sizes() << std::endl;

        return x;
    }

    int64_t embed_size;
    torch::nn::Sequential layers;
    torch::nn::Linear fc_out;
};
TORCH_MODULE(Transformer);

std::vector<std::string> load_data(const std::string& file_path) {
    std::ifstream file(file_path);
    std::vector<std::string> data;
    std::string line;
    while (std::getline(file, line)) {
        data.emplace_back(line);
    }
    return data;
}

torch::Tensor sha256_to_tensor(const std::string& sha256) {
    std::unordered_map<char, int64_t> char_to_idx = {
        {'0', 0}, {'1', 1}, {'2', 2}, {'3', 3},
        {'4', 4}, {'5', 5}, {'6', 6}, {'7', 7},
        {'8', 8}, {'9', 9}, {'A', 10}, {'B', 11},
        {'C', 12}, {'D', 13}, {'E', 14}, {'F', 15},
        {'a', 10}, {'b', 11}, {'c', 12}, {'d', 13},
        {'e', 14}, {'f', 15}
    };

    std::string trimmed_sha256 = sha256;
    if (trimmed_sha256.size() == 65 && trimmed_sha256[64] == ' ') {
        trimmed_sha256 = trimmed_sha256.substr(0, 64);
    }

    // Ensure the SHA-256 hash string is exactly 64 characters long
    if (trimmed_sha256.size() != 64) {
        throw std::runtime_error("Invalid SHA-256 hash length: " + std::to_string(trimmed_sha256.size()));
    }

    torch::Tensor tensor = torch::zeros({ 64 }, torch::kLong);
    for (size_t i = 0; i < trimmed_sha256.size(); ++i) {
        if (char_to_idx.find(trimmed_sha256[i]) != char_to_idx.end()) {
            tensor[i] = char_to_idx[trimmed_sha256[i]];
        }
        else {
            std::cerr << "Invalid character in SHA-256 hash: " << trimmed_sha256[i] << std::endl;
            throw std::runtime_error("Invalid character in SHA-256 hash: " + std::string(1, trimmed_sha256[i]));
        }
    }
    return tensor;
}

torch::Tensor preprocess_data(const std::vector<std::string>& data) {
    std::vector<torch::Tensor> tensor_data;
    for (const auto& hash : data) {
        tensor_data.emplace_back(sha256_to_tensor(hash));
    }
    return torch::stack(tensor_data);
}

class CustomDataset : public torch::data::Dataset<CustomDataset> {
public:
    CustomDataset(torch::Tensor data, torch::Tensor labels)
        : data_(data), labels_(labels) {
    }

    torch::data::Example<> get(size_t index) override {
        return { data_[index], labels_[index] };
    }

    torch::optional<size_t> size() const override {
        return data_.size(0);
    }

private:
    torch::Tensor data_, labels_;
};

void train_model(Transformer& model, torch::Device device, const std::vector<std::string>& clean_data, const std::vector<std::string>& malicious_data) {
    auto clean_tensor = preprocess_data(clean_data).to(device);
    auto malicious_tensor = preprocess_data(malicious_data).to(device);

    auto clean_labels = torch::zeros({ clean_tensor.size(0) }, torch::kLong).to(device);
    auto malicious_labels = torch::ones({ malicious_tensor.size(0) }, torch::kLong).to(device);

    auto data = torch::cat({ clean_tensor, malicious_tensor }, 0);
    auto labels = torch::cat({ clean_labels, malicious_labels }, 0);

    auto dataset = CustomDataset(data, labels).map(torch::data::transforms::Stack<>());
    auto data_loader = torch::data::make_data_loader(dataset, /*batch_size=*/32);

    torch::optim::Adam optimizer(model->parameters(), torch::optim::AdamOptions(1e-3));

    float v_loss = 0.0f;

    model->train();

    for (size_t epoch = 0; epoch < 10; ++epoch) {
        for (auto& batch : *data_loader) {
            auto inputs = batch.data.to(device);
            auto targets = batch.target.to(device);

            optimizer.zero_grad();
            auto outputs = model->forward(inputs);
            auto loss = torch::nn::functional::cross_entropy(outputs, targets);
            v_loss = loss.item<float>();
            loss.backward();
            optimizer.step();
        }
        std::cout << "Epoch [" << epoch + 1 << "/10], Loss: " << v_loss << std::endl;

        // Save the model every 5 epochs
        if ((epoch + 1) % 5 == 0) {
            std::stringstream stream;
            torch::save(model, stream);
            std::string model_path = "KloDA-t-mha-epoch-" + std::to_string(epoch + 1) + ".klo";
            std::ofstream file(model_path, std::ios::binary);
            file << stream.rdbuf();
            file.close();
        }
    }
}

int64_t calculate_total_parameters(int64_t embed_size, int64_t num_heads, int64_t num_layers, int64_t forward_expansion) {
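    // Rough hand count: the Linear layers' bias terms and the extra fc_out
    // inside my MultiHeadAttention wrapper are not included here.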
    // MultiHeadAttention parameters
    int64_t mha_params = 4 * embed_size * embed_size;

    // LayerNorm parameters
    int64_t layernorm_params = 2 * embed_size;

    // Feed Forward Network parameters
    int64_t ffn_params = embed_size * forward_expansion * embed_size + forward_expansion * embed_size * embed_size;

    // TransformerBlock parameters
    int64_t transformer_block_params = mha_params + 2 * layernorm_params + ffn_params;

    // Total parameters for all TransformerBlocks
    int64_t total_transformer_blocks_params = num_layers * transformer_block_params;

    // Final Linear layer parameters
    int64_t final_fc_params = embed_size * 2;

    // Total parameters
    int64_t total_params = total_transformer_blocks_params + final_fc_params;

    return total_params;
}

int main() {
    int64_t embed_size = 256;
    int64_t num_heads = 8;
    int64_t num_layers = 6;
    int64_t forward_expansion = 4;

    torch::Device device(torch::kCUDA);

    auto clean_data = load_data("Clean_SHA256.txt");
    auto malicious_data = load_data("KloDA-sha256-db.klo");

    int64_t total_params = calculate_total_parameters(embed_size, num_heads, num_layers, forward_expansion);
    std::cout << "Total number of parameters: " << total_params << std::endl;

    if (total_params > 500000000) {
        std::cerr << "The number of parameters exceeds 0.5 billion." << std::endl;
        return 1;
    }
    else {
        std::cout << "The number of parameters is within the limit of 0.5 billion." << std::endl;
    }

    Transformer model(embed_size, num_heads, num_layers, forward_expansion);
    model->to(device);

    train_model(model, device, clean_data, malicious_data);

    return 0;
}

Output:

Total number of parameters: 4725248
The number of parameters is within the limit of 0.5 billion.
Input shape: [32, 64]
Input shape: [32, 64]
Query shape: [32, 64]
Key shape: [32, 64]
Value shape: [32, 64]
embed_dim == embed_dim_to_check INTERNAL ASSERT FAILED at "C:\\actions-runner\\_work\\pytorch\\pytorch\\pytorch\\torch\\csrc\\api\\include\\torch/nn/functional/activation.h":672, please report a bug to PyTorch.
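
This looks like the same root cause to me: the built-in MultiheadAttention was constructed with embed_dim = 256, but query/key/value still have 64 as their last dimension, so the internal embed_dim check trips. On top of that, as far as I can tell MultiheadAttentionOptions in the C++ API has no batch_first flag, so I think the module expects inputs ordered as (seq_len, batch, embed_size). Here is a small standalone sketch of how I imagine the layer should be fed (again only my guess; the embedding and the transpose are not in the code above):

#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t embed_size = 256;
    const int64_t num_heads = 8;

    torch::nn::Embedding embedding(torch::nn::EmbeddingOptions(16, embed_size));
    torch::nn::MultiheadAttention mha(
        torch::nn::MultiheadAttentionOptions(embed_size, num_heads).dropout(0.1));

    auto x = torch::randint(0, 16, {32, 64}, torch::kLong);  // [batch, seq_len]
    auto e = embedding(x).transpose(0, 1);                   // [seq_len, batch, embed_size]

    torch::Tensor attn_out, attn_weights;
    std::tie(attn_out, attn_weights) = mha(e, e, e);
    std::cout << attn_out.sizes() << std::endl;              // expecting [64, 32, 256]

    return 0;
}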

I still can't get either version past these errors after a couple of rounds of debugging, and I'm not sure whether my guesses above are on the right track. Could someone help me pin down what's wrong in the two implementations?