diff --git a/.github/workflows/create-release.yml b/.github/workflows/create-release.yml
index d6a0308..3f41784 100644
--- a/.github/workflows/create-release.yml
+++ b/.github/workflows/create-release.yml
@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         home_assistant_version: ["2023.12.4", "2024.2.1"]
-        arch: [aarch64, amd64, i386] # armhf
+        arch: [aarch64, armhf, amd64, i386]
         suffix: [""]
         include:
           - home_assistant_version: "2024.2.1"
diff --git a/README.md b/README.md
index a69c1e1..cd06d7b 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,7 @@ In order to facilitate running the project entirely on the system where Home Ass
 ## Version History
 | Version | Description |
 |---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| v0.2.15 | Fix startup error when using llama.cpp backend and add flash attention to llama.cpp backend |
 | v0.2.14 | Fix llama.cpp wheels + AVX detection |
 | v0.2.13 | Add support for Llama 3, build llama.cpp wheels that are compatible with non-AVX systems, fix an error with exposing script entities, fix multiple small Ollama backend issues, and add basic multi-language support |
 | v0.2.12 | Fix cover ICL examples, allow setting number of ICL examples, add min P and typical P sampler options, recommend models during setup, add JSON mode for Ollama backend, fix missing default options |
diff --git a/custom_components/llama_conversation/agent.py b/custom_components/llama_conversation/agent.py
index b07a701..fe2b8df 100644
--- a/custom_components/llama_conversation/agent.py
+++ b/custom_components/llama_conversation/agent.py
@@ -41,6 +41,7 @@
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -75,6 +76,7 @@
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -548,6 +550,7 @@ def _load_model(self, entry: ConfigEntry) -> None:
             if not install_result == True:
                 raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")
 
+        validate_llama_cpp_python_installation()
         self.llama_cpp_module = importlib.import_module("llama_cpp")
         Llama = getattr(self.llama_cpp_module, "Llama")
 
@@ -558,13 +561,15 @@ def _load_model(self, entry: ConfigEntry) -> None:
         self.loaded_model_settings[CONF_BATCH_SIZE] = entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
         self.loaded_model_settings[CONF_THREAD_COUNT] = entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
         self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
+        self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
 
         self.llm = Llama(
             model_path=self.model_path,
             n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
             n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
             n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
-            n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
+            n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
+            flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
         )
 
         _LOGGER.debug("Model loaded")
@@ -613,13 +618,15 @@ def _update_options(self):
         if self.loaded_model_settings[CONF_CONTEXT_LENGTH] != self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH) or \
            self.loaded_model_settings[CONF_BATCH_SIZE] != self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE) or \
            self.loaded_model_settings[CONF_THREAD_COUNT] != self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT) or \
-           self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT):
+           self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] != self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT) or \
+           self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] != self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION):
 
             _LOGGER.debug(f"Reloading model '{self.model_path}'...")
             self.loaded_model_settings[CONF_CONTEXT_LENGTH] = self.entry.options.get(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH)
             self.loaded_model_settings[CONF_BATCH_SIZE] = self.entry.options.get(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE)
             self.loaded_model_settings[CONF_THREAD_COUNT] = self.entry.options.get(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT)
             self.loaded_model_settings[CONF_BATCH_THREAD_COUNT] = self.entry.options.get(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT)
+            self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION] = self.entry.options.get(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION)
 
             Llama = getattr(self.llama_cpp_module, "Llama")
             self.llm = Llama(
@@ -627,7 +634,8 @@ def _update_options(self):
                 n_ctx=int(self.loaded_model_settings[CONF_CONTEXT_LENGTH]),
                 n_batch=int(self.loaded_model_settings[CONF_BATCH_SIZE]),
                 n_threads=int(self.loaded_model_settings[CONF_THREAD_COUNT]),
-                n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT])
+                n_threads_batch=int(self.loaded_model_settings[CONF_BATCH_THREAD_COUNT]),
+                flash_attn=self.loaded_model_settings[CONF_ENABLE_FLASH_ATTENTION],
             )
             _LOGGER.debug("Model loaded")
             model_reloaded = True
@@ -894,15 +902,17 @@ def _generate(self, conversation: dict) -> str:
         if self.api_key:
             headers["Authorization"] = f"Bearer {self.api_key}"
 
-        result = requests.post(
-            f"{self.api_host}{endpoint}",
-            json=request_params,
-            timeout=timeout,
-            headers=headers,
-        )
-
         try:
+            result = requests.post(
+                f"{self.api_host}{endpoint}",
+                json=request_params,
+                timeout=timeout,
+                headers=headers,
+            )
+
             result.raise_for_status()
+        except requests.exceptions.Timeout:
+            return f"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
         except requests.RequestException as err:
             _LOGGER.debug(f"Err was: {err}")
             _LOGGER.debug(f"Request was: {request_params}")
@@ -1141,15 +1151,17 @@ def _generate(self, conversation: dict) -> str:
         if self.api_key:
             headers["Authorization"] = f"Bearer {self.api_key}"
 
-        result = requests.post(
-            f"{self.api_host}{endpoint}",
-            json=request_params,
-            timeout=timeout,
-            headers=headers,
-        )
-
         try:
+            result = requests.post(
+                f"{self.api_host}{endpoint}",
+                json=request_params,
+                timeout=timeout,
+                headers=headers,
+            )
+
             result.raise_for_status()
+        except requests.exceptions.Timeout:
+            return f"The generation request timed out! Please check your connection settings, increase the timeout in settings, or decrease the number of exposed entities."
         except requests.RequestException as err:
             _LOGGER.debug(f"Err was: {err}")
             _LOGGER.debug(f"Request was: {request_params}")
diff --git a/custom_components/llama_conversation/config_flow.py b/custom_components/llama_conversation/config_flow.py
index c75eb82..ed7e17f 100644
--- a/custom_components/llama_conversation/config_flow.py
+++ b/custom_components/llama_conversation/config_flow.py
@@ -54,6 +54,7 @@
     CONF_DOWNLOADED_MODEL_QUANTIZATION,
     CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -93,6 +94,7 @@
     DEFAULT_BACKEND_TYPE,
     DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
@@ -811,6 +813,11 @@ def local_llama_config_option_schema(options: MappingProxyType[str, Any], backen
             description={"suggested_value": options.get(CONF_BATCH_THREAD_COUNT)},
             default=DEFAULT_BATCH_THREAD_COUNT,
         ): NumberSelector(NumberSelectorConfig(min=1, max=(os.cpu_count() * 2), step=1)),
+        vol.Required(
+            CONF_ENABLE_FLASH_ATTENTION,
+            description={"suggested_value": options.get(CONF_ENABLE_FLASH_ATTENTION)},
+            default=DEFAULT_ENABLE_FLASH_ATTENTION,
+        ): BooleanSelector(BooleanSelectorConfig()),
         vol.Required(
             CONF_USE_GBNF_GRAMMAR,
             description={"suggested_value": options.get(CONF_USE_GBNF_GRAMMAR)},
diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py
index e235015..f6e8e47 100644
--- a/custom_components/llama_conversation/const.py
+++ b/custom_components/llama_conversation/const.py
@@ -120,6 +120,8 @@
         "generation_prompt": "<|start_header_id|>assistant<|end_header_id|>\n\n"
     }
 }
+CONF_ENABLE_FLASH_ATTENTION = "enable_flash_attention"
+DEFAULT_ENABLE_FLASH_ATTENTION = False
 CONF_USE_GBNF_GRAMMAR = "gbnf_grammar"
 DEFAULT_USE_GBNF_GRAMMAR = False
 CONF_GBNF_GRAMMAR_FILE = "gbnf_grammar_file"
@@ -178,6 +180,7 @@
     CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
     CONF_REQUEST_TIMEOUT: DEFAULT_REQUEST_TIMEOUT,
     CONF_PROMPT_TEMPLATE: DEFAULT_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION: DEFAULT_ENABLE_FLASH_ATTENTION,
    CONF_USE_GBNF_GRAMMAR: DEFAULT_USE_GBNF_GRAMMAR,
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE: DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS: DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
@@ -271,5 +274,5 @@
     }
 }
 
-INTEGRATION_VERSION = "0.2.14"
+INTEGRATION_VERSION = "0.2.15"
 EMBEDDED_LLAMA_CPP_PYTHON_VERSION = "0.2.69"
\ No newline at end of file
diff --git a/custom_components/llama_conversation/manifest.json b/custom_components/llama_conversation/manifest.json
index 6ddf608..3079207 100644
--- a/custom_components/llama_conversation/manifest.json
+++ b/custom_components/llama_conversation/manifest.json
@@ -1,7 +1,7 @@
 {
   "domain": "llama_conversation",
   "name": "LLaMA Conversation",
-  "version": "0.2.14",
+  "version": "0.2.15",
   "codeowners": ["@acon96"],
   "config_flow": true,
   "dependencies": ["conversation"],
diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json
index adfb594..24afee4 100644
--- a/custom_components/llama_conversation/translations/en.json
+++ b/custom_components/llama_conversation/translations/en.json
@@ -63,6 +63,7 @@
         "ollama_json_mode": "JSON Output Mode",
         "extra_attributes_to_expose": "Additional attribute to expose in the context",
         "allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
+        "enable_flash_attention": "Enable Flash Attention",
         "gbnf_grammar": "Enable GBNF Grammar",
         "gbnf_grammar_file": "GBNF Grammar Filename",
         "openai_api_key": "API Key",
@@ -115,6 +116,7 @@
         "ollama_json_mode": "JSON Output Mode",
         "extra_attributes_to_expose": "Additional attribute to expose in the context",
         "allowed_service_call_arguments": "Arguments allowed to be pass to service calls",
+        "enable_flash_attention": "Enable Flash Attention",
         "gbnf_grammar": "Enable GBNF Grammar",
         "gbnf_grammar_file": "GBNF Grammar Filename",
         "openai_api_key": "API Key",
diff --git a/custom_components/llama_conversation/utils.py b/custom_components/llama_conversation/utils.py
index a736906..a42654e 100644
--- a/custom_components/llama_conversation/utils.py
+++ b/custom_components/llama_conversation/utils.py
@@ -3,6 +3,7 @@
 import sys
 import platform
 import logging
+import multiprocessing
 import voluptuous as vol
 import webcolors
 from importlib.metadata import version
@@ -68,17 +69,23 @@ def download_model_from_hf(model_name: str, quantization_type: str, storage_fold
     )
 
 def _load_extension():
-    """This needs to be at the root file level because we are using the 'spawn' start method"""
+    """
+    Makes sure it is possible to load llama-cpp-python without crashing Home Assistant.
+    This needs to be at the root file level because we are using the 'spawn' start method.
+    Also ignore ModuleNotFoundError because that just means it's not installed. Not that it will crash HA
+    """
     import importlib
-    importlib.import_module("llama_cpp")
+    try:
+        importlib.import_module("llama_cpp")
+    except ModuleNotFoundError:
+        pass
 
 def validate_llama_cpp_python_installation():
     """
     Spawns another process and tries to import llama.cpp to avoid crashing the main process
     """
-    import multiprocessing
-    multiprocessing.set_start_method('spawn') # required because of aio
-    process = multiprocessing.Process(target=_load_extension)
+    mp_ctx = multiprocessing.get_context('spawn') # required because of aio
+    process = mp_ctx.Process(target=_load_extension)
     process.start()
     process.join()
 
diff --git a/tests/llama_conversation/test_agent.py b/tests/llama_conversation/test_agent.py
index dd6e312..49df0ce 100644
--- a/tests/llama_conversation/test_agent.py
+++ b/tests/llama_conversation/test_agent.py
@@ -20,6 +20,7 @@
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -55,6 +56,7 @@
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -208,6 +210,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )
 
     all_mocks["tokenize"].assert_called_once()
@@ -231,6 +234,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
     local_llama_agent.entry.options[CONF_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_BATCH_THREAD_COUNT] = 24
     local_llama_agent.entry.options[CONF_TEMPERATURE] = 2.0
+    local_llama_agent.entry.options[CONF_ENABLE_FLASH_ATTENTION] = True
     local_llama_agent.entry.options[CONF_TOP_K] = 20
     local_llama_agent.entry.options[CONF_TOP_P] = 0.9
     local_llama_agent.entry.options[CONF_MIN_P] = 0.2
@@ -244,6 +248,7 @@ async def test_local_llama_agent(local_llama_agent_fixture):
         n_batch=local_llama_agent.entry.options.get(CONF_BATCH_SIZE),
         n_threads=local_llama_agent.entry.options.get(CONF_THREAD_COUNT),
         n_threads_batch=local_llama_agent.entry.options.get(CONF_BATCH_THREAD_COUNT),
+        flash_attn=local_llama_agent.entry.options.get(CONF_ENABLE_FLASH_ATTENTION)
     )
 
     # do another turn of the same conversation
diff --git a/tests/llama_conversation/test_config_flow.py b/tests/llama_conversation/test_config_flow.py
index a28a0db..13fb4b5 100644
--- a/tests/llama_conversation/test_config_flow.py
+++ b/tests/llama_conversation/test_config_flow.py
@@ -26,6 +26,7 @@
     CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
     CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
     CONF_PROMPT_TEMPLATE,
+    CONF_ENABLE_FLASH_ATTENTION,
     CONF_USE_GBNF_GRAMMAR,
     CONF_GBNF_GRAMMAR_FILE,
     CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -67,6 +68,7 @@
     DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
     DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
     DEFAULT_PROMPT_TEMPLATE,
+    DEFAULT_ENABLE_FLASH_ATTENTION,
     DEFAULT_USE_GBNF_GRAMMAR,
     DEFAULT_GBNF_GRAMMAR_FILE,
     DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
@@ -304,7 +306,7 @@ def test_validate_options_schema():
     options_llama_hf = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_HF)
     assert set(options_llama_hf.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
@@ -313,7 +315,7 @@ def test_validate_options_schema():
     options_llama_existing = local_llama_config_option_schema(None, BACKEND_TYPE_LLAMA_EXISTING)
     assert set(options_llama_existing.keys()) == set(universal_options + [
         CONF_TOP_K, CONF_TEMPERATURE, CONF_TOP_P, CONF_MIN_P, CONF_TYPICAL_P, # supports all sampling parameters
-        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, # llama.cpp specific
+        CONF_BATCH_SIZE, CONF_THREAD_COUNT, CONF_BATCH_THREAD_COUNT, CONF_ENABLE_FLASH_ATTENTION, # llama.cpp specific
         CONF_CONTEXT_LENGTH, # supports context length
         CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, # supports GBNF
         CONF_PROMPT_CACHING_ENABLED, CONF_PROMPT_CACHING_INTERVAL # supports prompt caching
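
For reference, a condensed sketch of what the new enable_flash_attention option does once it reaches the llama.cpp backend: the value is read from the config entry options (defaulting to DEFAULT_ENABLE_FLASH_ATTENTION = False in const.py) and forwarded as the flash_attn keyword of the llama_cpp.Llama constructor, next to the existing context/batch/thread settings. This is a minimal standalone sketch, not the integration's code; the model path and option values below are illustrative placeholders.

    # Sketch of the option's data flow (assumes llama-cpp-python 0.2.69, the embedded version).
    # The model path and option values are placeholders, not part of this change.
    from llama_cpp import Llama

    options = {
        "context_length": 2048,
        "batch_size": 512,
        "thread_count": 4,
        "batch_thread_count": 4,
        "enable_flash_attention": True,  # the new toggle; defaults to False
    }

    llm = Llama(
        model_path="/media/models/example-model.q4_k_m.gguf",  # placeholder path
        n_ctx=int(options["context_length"]),
        n_batch=int(options["batch_size"]),
        n_threads=int(options["thread_count"]),
        n_threads_batch=int(options["batch_thread_count"]),
        flash_attn=options["enable_flash_attention"],  # passed straight through to llama.cpp
    )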
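
Similarly, a minimal standalone sketch of the startup-crash fix in utils.py: the llama_cpp import is attempted in a child process created from a 'spawn' multiprocessing context (rather than calling set_start_method, which may only be invoked once per interpreter), so a wheel built for the wrong CPU kills the child instead of Home Assistant, while a missing module is simply ignored. The __main__ guard and the exit-code return below are added for illustration only; how the integration reacts to the child's exit code is not shown in this diff.

    import multiprocessing

    def _load_extension():
        """Runs in the spawned child, so it must live at module (root file) level."""
        import importlib
        try:
            importlib.import_module("llama_cpp")
        except ModuleNotFoundError:
            pass  # not installed just means the wheel was never set up; that is not a crash

    def validate_llama_cpp_python_installation():
        """Try importing llama.cpp in a separate process so a bad binary cannot take down the caller."""
        mp_ctx = multiprocessing.get_context("spawn")  # a private context avoids clobbering the global start method
        process = mp_ctx.Process(target=_load_extension)
        process.start()
        process.join()
        return process.exitcode  # 0 on a clean import; negative signal number if the child crashed

    if __name__ == "__main__":
        print("llama_cpp import exit code:", validate_llama_cpp_python_installation())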