diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 849123a50..901930458 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1923,10 +1923,16 @@ def get_falcon_spec(self, model):
     def get_model_spec(self, model):
         self.get_falcon_spec(model)
 
-        if getattr(model.config, "multi_query", False):
-            num_heads_kv = 1
-        else:
+        if (getattr(model.config, "new_decoder_architecture", False)
+                and not getattr(model.config, "multi_query", False)):
             num_heads_kv = self._num_heads_kv
+        else:
+            num_heads_kv = 1
+
+        shared_layer_norm = False
+        if model.config.parallel_attn and (not hasattr(model.config, "num_ln_in_parallel_attn") or
+                model.config.num_ln_in_parallel_attn != 2):
+            shared_layer_norm = True
 
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             self._num_layers,
@@ -1939,7 +1945,7 @@ def get_model_spec(self, model):
             rotary_dim=0 if model.config.rotary else None,
             rotary_interleave=False,
             parallel_residual=model.config.parallel_attn,
-            shared_layer_norm=num_heads_kv == 1,
+            shared_layer_norm=shared_layer_norm,
             num_heads_kv=num_heads_kv,
         )
 
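
For context, the patch decouples shared_layer_norm from num_heads_kv == 1 and derives both flags directly from the model config. Below is a minimal standalone sketch of that flag logic; the resolve_falcon_flags helper and the example config values (a Falcon-7B-style multi-query checkpoint and a Falcon2-style checkpoint with two layer norms in the parallel block) are illustrative assumptions, not code from the converter.

# Minimal sketch of the patched flag logic, using stand-in configs.
# resolve_falcon_flags and the example attribute values are hypothetical,
# not part of CTranslate2.
from types import SimpleNamespace


def resolve_falcon_flags(config, num_heads_kv_from_checkpoint):
    # Grouped KV heads are kept only for new-architecture checkpoints
    # that are not multi-query; everything else collapses to 1 KV head.
    if getattr(config, "new_decoder_architecture", False) and not getattr(
        config, "multi_query", False
    ):
        num_heads_kv = num_heads_kv_from_checkpoint
    else:
        num_heads_kv = 1

    # The parallel attention/MLP branches share one layer norm unless the
    # config explicitly requests two of them (num_ln_in_parallel_attn == 2).
    shared_layer_norm = config.parallel_attn and (
        not hasattr(config, "num_ln_in_parallel_attn")
        or config.num_ln_in_parallel_attn != 2
    )
    return num_heads_kv, shared_layer_norm


# Hypothetical config values for illustration only.
multi_query_model = SimpleNamespace(multi_query=True, parallel_attn=True)
two_ln_model = SimpleNamespace(
    new_decoder_architecture=True,
    multi_query=False,
    parallel_attn=True,
    num_ln_in_parallel_attn=2,
)

print(resolve_falcon_flags(multi_query_model, 8))  # (1, True)
print(resolve_falcon_flags(two_ln_model, 8))       # (8, False)

With the old logic, the two-layer-norm checkpoint would have used a single shared layer norm whenever it ended up with one KV head; the new logic lets the config's num_ln_in_parallel_attn setting override that.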