I am building a custom model that consist of multiple pre-trained models (e.g. BERT). I want to have a global PretrainedConfig
that control the setup these group of pre-trained models in a higher level.
Examples such as:
from transformers import PretrainedConfig, BertConfig, GPT2Config
class GlobalConfig(PretrainedConfig):
def __init__(self, model_a_cfg=None, model_b_cfg=None, *inputs, **kwargs):
super().__init__()
self.model_a_cfg = model_a_cfg
self.model_b_cfg = model_b_cfg
class Config1(BertConfig):
def __init__(self, *inputs, **kwargs):
super().__init__()
class Config2(GPT2Config):
def __init__(self, *inputs, **kwargs):
super().__init__()
This results in:
bert_config = Config1()
gpt_config = Config2()
GlobalConfig(bert_config, gpt_config)
# TypeError: Object of type Config1 is not JSON serializable
I expect the GlobalConfig
work out like this
from transformers import BertModel, GPT2Model
class GlobalModel(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.model_a = BertModel(config.model_a_cfg)
self.model_b = GPT2Model(config.model_b_cfg)
config = GlobalConfig(bert_config, gpt_config)
global_model = GlobalModel(config)
I’d like to avoid managing multiple configurations for a large combined model as it can be cumbersome. What would be the most effective solution to this issue?