From 7fe98c67a3da44ba1ca661f51d332465c4f77821 Mon Sep 17 00:00:00 2001
From: softmix
Date: Fri, 29 Mar 2024 21:59:57 +0100
Subject: [PATCH] gracefully fail loading

This would crash when models are restricted; instead, log the error and
allow using the other ones.
---
 utils.py | 43 +++++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/utils.py b/utils.py
index 45f82cd..ef5252a 100644
--- a/utils.py
+++ b/utils.py
@@ -15,26 +15,29 @@ def load_models() -> Dict[str, Tuple[PreTrainedModel, PreTrainedTokenizer, Model
     models = {}
     for family in config.MODEL_FAMILIES.values():
         for model_config in family:
-            backend_config = model_config.backend
-
-            logger.info(f"Loading tokenizer for {backend_config.repository}")
-            tokenizer = AutoTokenizer.from_pretrained(backend_config.repository, add_bos_token=False, use_fast=False)
-
-            logger.info(
-                f"Loading model {backend_config.repository} with adapter {backend_config.adapter} in {config.TORCH_DTYPE}"
-            )
-            # We set use_fast=False since LlamaTokenizerFast takes a long time to init
-            model = AutoDistributedModelForCausalLM.from_pretrained(
-                backend_config.repository,
-                active_adapter=backend_config.adapter,
-                torch_dtype=config.TORCH_DTYPE,
-                initial_peers=config.INITIAL_PEERS,
-                max_retries=3,
-            )
-            model = model.to(config.DEVICE)
-
-            for key in [backend_config.key] + list(backend_config.aliases):
-                models[key] = model, tokenizer, backend_config
+            try:
+                backend_config = model_config.backend
+
+                logger.info(f"Loading tokenizer for {backend_config.repository}")
+                tokenizer = AutoTokenizer.from_pretrained(backend_config.repository, add_bos_token=False, use_fast=False)
+
+                logger.info(
+                    f"Loading model {backend_config.repository} with adapter {backend_config.adapter} in {config.TORCH_DTYPE}"
+                )
+                # We set use_fast=False since LlamaTokenizerFast takes a long time to init
+                model = AutoDistributedModelForCausalLM.from_pretrained(
+                    backend_config.repository,
+                    active_adapter=backend_config.adapter,
+                    torch_dtype=config.TORCH_DTYPE,
+                    initial_peers=config.INITIAL_PEERS,
+                    max_retries=3,
+                )
+                model = model.to(config.DEVICE)
+
+                for key in [backend_config.key] + list(backend_config.aliases):
+                    models[key] = model, tokenizer, backend_config
+            except Exception as e:
+                logger.error(f"Failed to load model {model_config.backend.repository} due to {e}")
     return models
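
For illustration, a minimal standalone sketch of the control flow this patch
produces: a restricted or otherwise failing model is logged and skipped while
the remaining ones keep loading. The MODEL_FAMILIES dict and load_one() helper
below are hypothetical stand-ins for config.MODEL_FAMILIES and the real
tokenizer/model loading in utils.py; only the try/except-per-model pattern
mirrors the patch.

    import logging
    from typing import Dict, Tuple

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Hypothetical stand-in for config.MODEL_FAMILIES: one public and one
    # restricted repository.
    MODEL_FAMILIES = {
        "llama": ["example/public-model", "example/restricted-model"],
    }


    def load_one(repository: str) -> Tuple[str, str]:
        # Stand-in for the tokenizer + model loading that can raise, e.g.
        # when a repository is gated/restricted on the Hub.
        if "restricted" in repository:
            raise PermissionError(f"{repository} is restricted")
        return f"model:{repository}", f"tokenizer:{repository}"


    def load_models() -> Dict[str, Tuple[str, str]]:
        models = {}
        for family in MODEL_FAMILIES.values():
            for repository in family:
                try:
                    # One failing model no longer aborts the whole startup.
                    models[repository] = load_one(repository)
                except Exception as e:
                    logger.error(f"Failed to load model {repository} due to {e}")
        return models


    if __name__ == "__main__":
        # Prints only the loadable model; the restricted one is logged and skipped.
        print(sorted(load_models()))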