From e5c3b28b6ea6f2ea0b8407bc2e0b2d47a92d4caf Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Mon, 16 Dec 2024 17:43:03 +0100 Subject: [PATCH] fix: checkpoint conversion to HF --- src/modalities/models/huggingface_adapters/hf_adapter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/modalities/models/huggingface_adapters/hf_adapter.py b/src/modalities/models/huggingface_adapters/hf_adapter.py index 09c75183..e8e5de6d 100644 --- a/src/modalities/models/huggingface_adapters/hf_adapter.py +++ b/src/modalities/models/huggingface_adapters/hf_adapter.py @@ -17,7 +17,7 @@ class HFModelAdapterConfig(PretrainedConfig): model_type = "modalities" - def __init__(self, **kwargs): + def __init__(self, config={}, **kwargs): """ Initializes an HFModelAdapterConfig object. @@ -28,6 +28,7 @@ def __init__(self, **kwargs): ConfigError: If the config is not passed in HFModelAdapterConfig. """ super().__init__(**kwargs) + self.config = config # self.config is added by the super class via kwargs if self.config is None: raise ConfigError("Config is not passed in HFModelAdapterConfig.") @@ -115,7 +116,7 @@ def forward( raise NotImplementedError model_input = {"input_ids": input_ids, "attention_mask": attention_mask} model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) - if return_dict: + if not return_dict: return ModalitiesModelOutput(**model_forward_output) else: return model_forward_output[self.prediction_key]