From b2736a96ae9cc0a5dcb6fa5c87d776a6f865db0c Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Thu, 24 Oct 2024 11:41:07 +0800 Subject: [PATCH 1/7] ollama support. --- graphrag/config/enums.py | 3 + graphrag/index/llm/load_llm.py | 110 +++- graphrag/llm/__init__.py | 21 + graphrag/llm/ollama/__init__.py | 29 ++ graphrag/llm/ollama/create_ollama_client.py | 30 ++ graphrag/llm/ollama/factories.py | 139 +++++ graphrag/llm/ollama/json_parsing_llm.py | 38 ++ graphrag/llm/ollama/ollama_chat_llm.py | 60 +++ graphrag/llm/ollama/ollama_completion_llm.py | 44 ++ graphrag/llm/ollama/ollama_configuration.py | 493 ++++++++++++++++++ graphrag/llm/ollama/ollama_embeddings_llm.py | 40 ++ graphrag/llm/ollama/types.py | 8 + graphrag/llm/openai/factories.py | 15 +- graphrag/llm/openai/json_parsing_llm.py | 2 +- graphrag/llm/openai/openai_chat_llm.py | 8 +- graphrag/llm/openai/openai_completion_llm.py | 2 +- graphrag/llm/openai/openai_configuration.py | 36 +- .../llm/openai/openai_token_replacing_llm.py | 2 +- graphrag/llm/types/llm_config.py | 4 + graphrag/llm/{openai => }/utils.py | 35 +- .../structured_search/global_search/search.py | 2 +- pyproject.toml | 1 + 22 files changed, 1071 insertions(+), 51 deletions(-) create mode 100644 graphrag/llm/ollama/__init__.py create mode 100644 graphrag/llm/ollama/create_ollama_client.py create mode 100644 graphrag/llm/ollama/factories.py create mode 100644 graphrag/llm/ollama/json_parsing_llm.py create mode 100644 graphrag/llm/ollama/ollama_chat_llm.py create mode 100644 graphrag/llm/ollama/ollama_completion_llm.py create mode 100644 graphrag/llm/ollama/ollama_configuration.py create mode 100644 graphrag/llm/ollama/ollama_embeddings_llm.py create mode 100644 graphrag/llm/ollama/types.py rename graphrag/llm/{openai => }/utils.py (84%) diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 8741cf74ae..4410b272c4 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -98,14 +98,17 @@ class LLMType(str, Enum): # Embeddings OpenAIEmbedding = "openai_embedding" AzureOpenAIEmbedding = "azure_openai_embedding" + OllamaEmbedding = "ollama_embedding" # Raw Completion OpenAI = "openai" AzureOpenAI = "azure_openai" + Ollama = "ollama" # Chat Completion OpenAIChat = "openai_chat" AzureOpenAIChat = "azure_openai_chat" + OllamaChat = "ollama_chat" # Debug StaticResponse = "static_response" diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index a7eda31a4e..f3d58aed2a 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -22,6 +22,12 @@ create_openai_completion_llm, create_openai_embedding_llm, create_tpm_rpm_limiters, + OllamaConfiguration, + create_ollama_client, + create_ollama_chat_llm, + create_ollama_embedding_llm, + create_ollama_completion_llm, + LLMConfig, ) if TYPE_CHECKING: @@ -46,7 +52,7 @@ def load_llm( ) -> CompletionLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) - + print(llm_type.value) if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type {llm_type} does not support chat" @@ -182,6 +188,50 @@ def _load_azure_openai_embeddings_llm( return _load_openai_embeddings_llm(on_error, cache, config, True) +def _load_ollama_completion_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + return _create_ollama_completion_llm( + OllamaConfiguration({ + **_get_base_config(config), + }), + on_error, + cache, + ) + + +def _load_ollama_chat_llm( + on_error: 
ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + return _create_ollama_chat_llm( + OllamaConfiguration({ + # Set default values + **_get_base_config(config), + }), + on_error, + cache, + ) + + +def _load_ollama_embeddings_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + # TODO: Inject Cache + return _create_ollama_embeddings_llm( + OllamaConfiguration({ + **_get_base_config(config), + }), + on_error, + cache, + ) + + def _get_base_config(config: dict[str, Any]) -> dict[str, Any]: api_key = config.get("api_key") @@ -218,6 +268,10 @@ def _load_static_response( "load": _load_azure_openai_completion_llm, "chat": False, }, + LLMType.Ollama: { + "load": _load_ollama_completion_llm, + "chat": False, + }, LLMType.OpenAIChat: { "load": _load_openai_chat_llm, "chat": True, @@ -226,6 +280,10 @@ def _load_static_response( "load": _load_azure_openai_chat_llm, "chat": True, }, + LLMType.OllamaChat: { + "load": _load_ollama_chat_llm, + "chat": True, + }, LLMType.OpenAIEmbedding: { "load": _load_openai_embeddings_llm, "chat": False, @@ -234,6 +292,10 @@ def _load_static_response( "load": _load_azure_openai_embeddings_llm, "chat": False, }, + LLMType.OllamaEmbedding: { + "load": _load_ollama_embeddings_llm, + "chat": False, + }, LLMType.StaticResponse: { "load": _load_static_response, "chat": False, @@ -286,7 +348,49 @@ def _create_openai_embeddings_llm( ) -def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter: +def _create_ollama_chat_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> CompletionLLM: + """Create an Ollama chat llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_chat_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_ollama_completion_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> CompletionLLM: + """Create an Ollama completion llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_completion_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_ollama_embeddings_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> EmbeddingLLM: + """Create an Ollama embeddings llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_embedding_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_limiter(configuration: LLMConfig) -> LLMLimiter: limit_name = configuration.model or configuration.deployment_name or "default" if limit_name not in _rate_limiters: tpm = configuration.tokens_per_minute @@ -296,7 +400,7 @@ def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter: return _rate_limiters[limit_name] -def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None: +def _create_semaphore(configuration: LLMConfig) -> asyncio.Semaphore | None: limit_name = configuration.model or configuration.deployment_name or "default" concurrency = configuration.concurrent_requests diff --git a/graphrag/llm/__init__.py b/graphrag/llm/__init__.py index 609be951b2..508600ee48 100644 
--- a/graphrag/llm/__init__.py +++ b/graphrag/llm/__init__.py @@ -42,6 +42,17 @@ LLMOutput, OnCacheActionFn, ) +from .ollama import ( + OllamaChatLLM, + OllamaClientType, + OllamaCompletionLLM, + OllamaConfiguration, + OllamaEmbeddingsLLM, + create_ollama_chat_llm, + create_ollama_client, + create_ollama_completion_llm, + create_ollama_embedding_llm, +) __all__ = [ # LLM Types @@ -79,6 +90,12 @@ "OpenAIConfiguration", "OpenAIEmbeddingsLLM", "RateLimitingLLM", + # Ollama + "OllamaChatLLM", + "OllamaClientType", + "OllamaCompletionLLM", + "OllamaConfiguration", + "OllamaEmbeddingsLLM", # Errors "RetriesExhaustedError", "TpmRpmLLMLimiter", @@ -86,6 +103,10 @@ "create_openai_client", "create_openai_completion_llm", "create_openai_embedding_llm", + "create_ollama_chat_llm", + "create_ollama_client", + "create_ollama_completion_llm", + "create_ollama_embedding_llm", # Limiters "create_tpm_rpm_limiters", ] diff --git a/graphrag/llm/ollama/__init__.py b/graphrag/llm/ollama/__init__.py new file mode 100644 index 0000000000..adf27d50a5 --- /dev/null +++ b/graphrag/llm/ollama/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Ollama LLM implementations.""" + +from .create_ollama_client import create_ollama_client +from .factories import ( + create_ollama_chat_llm, + create_ollama_completion_llm, + create_ollama_embedding_llm, +) +from .ollama_chat_llm import OllamaChatLLM +from .ollama_completion_llm import OllamaCompletionLLM +from .ollama_configuration import OllamaConfiguration +from .ollama_embeddings_llm import OllamaEmbeddingsLLM +from .types import OllamaClientType + + +__all__ = [ + "OllamaChatLLM", + "OllamaClientType", + "OllamaCompletionLLM", + "OllamaConfiguration", + "OllamaEmbeddingsLLM", + "create_ollama_chat_llm", + "create_ollama_client", + "create_ollama_completion_llm", + "create_ollama_embedding_llm", +] diff --git a/graphrag/llm/ollama/create_ollama_client.py b/graphrag/llm/ollama/create_ollama_client.py new file mode 100644 index 0000000000..87f5902c44 --- /dev/null +++ b/graphrag/llm/ollama/create_ollama_client.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Create OpenAI client instance.""" + +import logging +from functools import cache + +from ollama import AsyncClient + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + +log = logging.getLogger(__name__) + +API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client" + + +@cache +def create_ollama_client( + configuration: OllamaConfiguration +) -> OllamaClientType: + """Create a new Ollama client instance.""" + + log.info("Creating OpenAI client base_url=%s", configuration.api_base) + return AsyncClient( + host=configuration.api_base, + # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here + timeout=configuration.request_timeout or 180.0, + ) diff --git a/graphrag/llm/ollama/factories.py b/graphrag/llm/ollama/factories.py new file mode 100644 index 0000000000..e203cf48a1 --- /dev/null +++ b/graphrag/llm/ollama/factories.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 Microsoft Corporation. 
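
For reference, a minimal sketch of how the new client factory above is meant to be used; the model name and base URL are placeholders, and the `api_key` value only satisfies the configuration schema (Ollama itself does not use it):

```python
from graphrag.llm.ollama import OllamaConfiguration, create_ollama_client

# Placeholder values; any locally pulled model tag and reachable Ollama host will do.
config = OllamaConfiguration({
    "api_key": "ollama",                  # required by the config schema, unused by Ollama
    "model": "llama3.1",
    "api_base": "http://localhost:11434",
    "request_timeout": 180.0,
})

client = create_ollama_client(config)
# create_ollama_client is wrapped in functools.cache, so repeated calls with the
# same (hashable) configuration return the same AsyncClient instance.
assert client is create_ollama_client(config)
```
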
+# Licensed under the MIT License + +"""Factory functions for creating OpenAI LLMs.""" + +import asyncio + +from graphrag.llm.base import CachingLLM, RateLimitingLLM +from graphrag.llm.limiting import LLMLimiter +from graphrag.llm.types import ( + LLM, + CompletionLLM, + EmbeddingLLM, + ErrorHandlerFn, + LLMCache, + LLMInvocationFn, + OnCacheActionFn, +) +from graphrag.llm.utils import ( + RATE_LIMIT_ERRORS, + RETRYABLE_ERRORS, + get_sleep_time_from_error, + get_token_counter, +) +from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM +from graphrag.llm.openai.openai_token_replacing_llm import OpenAITokenReplacingLLM + +from .json_parsing_llm import JsonParsingLLM +from .ollama_chat_llm import OllamaChatLLM +from .ollama_completion_llm import OllamaCompletionLLM +from .ollama_configuration import OllamaConfiguration +from .ollama_embeddings_llm import OllamaEmbeddingsLLM +from .types import OllamaClientType + + +def create_ollama_chat_llm( + client: OllamaClientType, + config: OllamaConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI chat LLM.""" + operation = "chat" + result = OllamaChatLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + result = OpenAIHistoryTrackingLLM(result) + result = OpenAITokenReplacingLLM(result) + return JsonParsingLLM(result) + + +def create_ollama_completion_llm( + client: OllamaClientType, + config: OllamaConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> CompletionLLM: + """Create an OpenAI completion LLM.""" + operation = "completion" + result = OllamaCompletionLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return OpenAITokenReplacingLLM(result) + + +def create_ollama_embedding_llm( + client: OllamaClientType, + config: OllamaConfiguration, + cache: LLMCache | None = None, + limiter: LLMLimiter | None = None, + semaphore: asyncio.Semaphore | None = None, + on_invoke: LLMInvocationFn | None = None, + on_error: ErrorHandlerFn | None = None, + on_cache_hit: OnCacheActionFn | None = None, + on_cache_miss: OnCacheActionFn | None = None, +) -> EmbeddingLLM: + """Create an OpenAI embeddings LLM.""" + operation = "embedding" + result = OllamaEmbeddingsLLM(client, config) + result.on_error(on_error) + if limiter is not None or semaphore is not None: + result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke) + if cache is not None: + result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss) + return result + + +def _rate_limited( + delegate: LLM, + config: 
OllamaConfiguration, + operation: str, + limiter: LLMLimiter | None, + semaphore: asyncio.Semaphore | None, + on_invoke: LLMInvocationFn | None, +): + result = RateLimitingLLM( + delegate, + config, + operation, + RETRYABLE_ERRORS, + RATE_LIMIT_ERRORS, + limiter, + semaphore, + get_token_counter(config), + get_sleep_time_from_error, + ) + result.on_invoke(on_invoke) + return result + + +def _cached( + delegate: LLM, + config: OllamaConfiguration, + operation: str, + cache: LLMCache, + on_cache_hit: OnCacheActionFn | None, + on_cache_miss: OnCacheActionFn | None, +): + cache_args = config.get_completion_cache_args() + result = CachingLLM(delegate, cache_args, operation, cache) + result.on_cache_hit(on_cache_hit) + result.on_cache_miss(on_cache_miss) + return result diff --git a/graphrag/llm/ollama/json_parsing_llm.py b/graphrag/llm/ollama/json_parsing_llm.py new file mode 100644 index 0000000000..588a0480c1 --- /dev/null +++ b/graphrag/llm/ollama/json_parsing_llm.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""An LLM that unpacks cached JSON responses.""" + +from typing_extensions import Unpack + +from graphrag.llm.types import ( + LLM, + CompletionInput, + CompletionLLM, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from graphrag.llm.utils import try_parse_json_object + + +class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]): + """An OpenAI History-Tracking LLM.""" + + _delegate: CompletionLLM + + def __init__(self, delegate: CompletionLLM): + self._delegate = delegate + + async def __call__( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Call the LLM with the input and kwargs.""" + result = await self._delegate(input, **kwargs) + if kwargs.get("json") and result.json is None and result.output is not None: + _, parsed_json = try_parse_json_object(result.output) + result.json = parsed_json + return result diff --git a/graphrag/llm/ollama/ollama_chat_llm.py b/graphrag/llm/ollama/ollama_chat_llm.py new file mode 100644 index 0000000000..a9ba70596b --- /dev/null +++ b/graphrag/llm/ollama/ollama_chat_llm.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024 Microsoft Corporation. 
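
To make the layering in `create_ollama_chat_llm` above concrete, a hedged usage sketch with no cache, limiter, or semaphore, so only the history-tracking, token-replacing, and JSON-parsing wrappers are applied; the model name, host, and prompt are placeholders and assume a local Ollama server:

```python
import asyncio

from graphrag.llm import (
    OllamaConfiguration,
    create_ollama_chat_llm,
    create_ollama_client,
)

async def main() -> None:
    config = OllamaConfiguration({
        "api_key": "ollama",
        "model": "llama3.1",
        "api_base": "http://localhost:11434",
    })
    # cache/limiter/semaphore default to None, so those layers are skipped.
    llm = create_ollama_chat_llm(create_ollama_client(config), config)
    result = await llm("Summarize what GraphRAG does in one sentence.")
    print(result.output)

asyncio.run(main())
```
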
+# Licensed under the MIT License + +"""The Chat-based language model.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + +log = logging.getLogger(__name__) + +_MAX_GENERATION_RETRIES = 3 +FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output" + + +class OllamaChatLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A Chat-based LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> CompletionOutput | None: + args = { + **self.configuration.get_chat_cache_args(), + **(kwargs.get("model_parameters") or {}), + } + history = kwargs.get("history") or [] + messages = [ + *history, + {"role": "user", "content": input}, + ] + completion = await self.client.chat( + messages=messages, **args + ) + return completion["message"]["content"] + + async def _invoke_json( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output.""" + pass diff --git a/graphrag/llm/ollama/ollama_completion_llm.py b/graphrag/llm/ollama/ollama_completion_llm.py new file mode 100644 index 0000000000..4102418def --- /dev/null +++ b/graphrag/llm/ollama/ollama_completion_llm.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A text-completion based LLM.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, +) +from graphrag.llm.utils import get_completion_llm_args + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + + +log = logging.getLogger(__name__) + + +class OllamaCompletionLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A text-completion based LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput | None: + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) + completion = await self.client.generate(prompt=input, **args) + return completion["response"] diff --git a/graphrag/llm/ollama/ollama_configuration.py b/graphrag/llm/ollama/ollama_configuration.py new file mode 100644 index 0000000000..468b3c5b33 --- /dev/null +++ b/graphrag/llm/ollama/ollama_configuration.py @@ -0,0 +1,493 @@ +# Copyright (c) 2024 Microsoft Corporation. 
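
The two wrappers above lean on ollama-python's AsyncClient; as a reference point, a minimal sketch of the underlying calls and the response keys they read. The model name is a placeholder, and the dict-style access matches the ollama 0.3.x line pinned later in pyproject.toml:

```python
import asyncio

from ollama import AsyncClient

async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")

    # What OllamaChatLLM delegates to: chat() and the message content key.
    chat = await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(chat["message"]["content"])

    # What OllamaCompletionLLM delegates to: generate() and the response key.
    gen = await client.generate(model="llama3.1", prompt="Say hello.")
    print(gen["response"])

asyncio.run(main())
```
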
+# Licensed under the MIT License + +"""Ollama Configuration class definition.""" +import json +from collections.abc import Hashable +from typing import cast, Any + +from graphrag.llm import LLMConfig +from graphrag.llm.utils import non_blank, non_none_value_key + + +class OllamaConfiguration(Hashable, LLMConfig): + """OpenAI Configuration class definition.""" + + # Core Configuration + _api_key: str + _model: str + + _api_base: str | None + _api_version: str | None + _organization: str | None + + # Operation Configuration + _n: int | None + _temperature: float | None + _top_p: float | None + _format: str | None + _stop: str | None + _mirostat: int | None + _mirostat_eta: float | None + _mirostat_tau: float | None + _num_ctx: int | None + _repeat_last_n: int | None + _repeat_penalty: float | None + _frequency_penalty: float | None + _seed: int | None + _tfs_z: float | None + _num_predict: int | None + _top_k: int | None + _min_p: float | None + _options: dict | None + _suffix: str | None + _system: str | None + _template: str | None + _raw: bool | None + _keep_alive: int | None + _stream: bool | None + + + # Retry Logic + _max_retries: int | None + _max_retry_wait: float | None + _request_timeout: float | None + + # The raw configuration object + _raw_config: dict + + # Feature Flags + _model_supports_json: bool | None + + # Custom Configuration + _tokens_per_minute: int | None + _requests_per_minute: int | None + _concurrent_requests: int | None + _encoding_model: str | None + _sleep_on_rate_limit_recommendation: bool | None + + def __init__( + self, + config: dict, + ): + """Init method definition.""" + + def lookup_required(key: str) -> str: + return cast(str, config.get(key)) + + def lookup_str(key: str) -> str | None: + return cast(str | None, config.get(key)) + + def lookup_int(key: str) -> int | None: + result = config.get(key) + if result is None: + return None + return int(cast(int, result)) + + def lookup_float(key: str) -> float | None: + result = config.get(key) + if result is None: + return None + return float(cast(float, result)) + + def lookup_dict(key: str) -> dict | None: + return cast(dict | None, config.get(key)) + + def lookup_list(key: str) -> list | None: + return cast(list | None, config.get(key)) + + def lookup_bool(key: str) -> bool | None: + value = config.get(key) + if isinstance(value, str): + return value.upper() == "TRUE" + if isinstance(value, int): + return value > 0 + return cast(bool | None, config.get(key)) + + self._api_key = lookup_required("api_key") + self._model = lookup_required("model") + self._api_base = lookup_str("api_base") + self._api_version = lookup_str("api_version") + self._organization = lookup_str("organization") + self._n = lookup_int("n") + self._temperature = lookup_float("temperature") + self._top_p = lookup_float("top_p") + self._stop = lookup_str("stop") + self._mirostat = lookup_int("mirostat") + self._mirostat_eta = lookup_float("mirostat_eta") + self._mirostat_tau = lookup_float("mirostat_tau") + self._num_ctx = lookup_int("num_ctx") + self._repeat_last_n = lookup_int("repeat_last_n") + self._repeat_penalty = lookup_float("repeat_penalty") + self._frequency_penalty = lookup_float("frequency_penalty") + self._seed = lookup_int("seed") + self._tfs_z = lookup_float("tfs_z") + self._num_predict = lookup_int("num_predict") + self._top_k = lookup_int("top_k") + self._min_p = lookup_float("min_p") + self._suffix = lookup_str("suffix") + self._system = lookup_str("system") + self._template = lookup_str("template") + self._raw = 
lookup_bool("raw") + self._keep_alive = lookup_int("keep_alive") + self._stream = lookup_bool("stream") + self._format = lookup_str("response_format") + self._max_retries = lookup_int("max_retries") + self._request_timeout = lookup_float("request_timeout") + self._model_supports_json = lookup_bool("model_supports_json") + self._tokens_per_minute = lookup_int("tokens_per_minute") + self._requests_per_minute = lookup_int("requests_per_minute") + self._concurrent_requests = lookup_int("concurrent_requests") + self._encoding_model = lookup_str("encoding_model") + self._max_retry_wait = lookup_float("max_retry_wait") + self._sleep_on_rate_limit_recommendation = lookup_bool( + "sleep_on_rate_limit_recommendation" + ) + self._raw_config = config + self._options = { + "n": self._n, + "temperature": self._temperature, + "top_p": self._top_p, + "format": self._format, + "stop": self._stop, + "mirostat": self._mirostat, + "mirostat_eta": self._mirostat_eta, + "mirostat_tau": self._mirostat_tau, + "num_ctx": self._num_ctx, + "repeat_last_n": self._repeat_last_n, + "repeat_penalty": self._repeat_penalty, + "frequency_penalty": self._frequency_penalty, + "seed": self._seed, + "tfs_z": self._tfs_z, + "num_predict": self._num_predict, + "top_k": self._top_k, + "min_p": self._min_p, + "suffix": self._suffix, + "system": self._system, + "template": self._template, + "raw": self._raw, + "keep_alive": self._keep_alive, + } + + @property + def api_key(self) -> str: + """API key property definition.""" + return self._api_key + + @property + def model(self) -> str: + """Model property definition.""" + return self._model + + @property + def api_base(self) -> str | None: + """API base property definition.""" + result = non_blank(self._api_base) + # Remove trailing slash + return result[:-1] if result and result.endswith("/") else result + + @property + def api_version(self) -> str | None: + """API version property definition.""" + return non_blank(self._api_version) + + @property + def organization(self) -> str | None: + """Organization property definition.""" + return non_blank(self._organization) + + @property + def n(self) -> int | None: + """N property definition.""" + return self._n + + @property + def temperature(self) -> float | None: + """Temperature property definition.""" + return self._temperature + + @property + def frequency_penalty(self) -> float | None: + """Frequency penalty property definition.""" + return self._frequency_penalty + + @property + def top_p(self) -> float | None: + """Top p property definition.""" + return self._top_p + + @property + def stop(self) -> str | None: + """Stop property definition.""" + return self._stop + + @property + def max_retries(self) -> int | None: + """Max retries property definition.""" + return self._max_retries + + @property + def max_retry_wait(self) -> float | None: + """Max retry wait property definition.""" + return self._max_retry_wait + + @property + def request_timeout(self) -> float | None: + """Request timeout property definition.""" + return self._request_timeout + + @property + def model_supports_json(self) -> bool | None: + """Model supports json property definition.""" + return self._model_supports_json + + @property + def tokens_per_minute(self) -> int | None: + """Tokens per minute property definition.""" + return self._tokens_per_minute + + @property + def requests_per_minute(self) -> int | None: + """Requests per minute property definition.""" + return self._requests_per_minute + + @property + def concurrent_requests(self) -> int | None: + 
"""Concurrent requests property definition.""" + return self._concurrent_requests + + @property + def encoding_model(self) -> str | None: + """Encoding model property definition.""" + return non_blank(self._encoding_model) + + @property + def sleep_on_rate_limit_recommendation(self) -> bool | None: + """Whether to sleep for seconds when recommended by 429 errors (azure-specific).""" + return self._sleep_on_rate_limit_recommendation + + @property + def raw_config(self) -> dict: + """Raw config method definition.""" + return self._raw_config + + @property + def format(self) -> str | None: + """The format to return a response in. Currently the only accepted value is json""" + return self._format + + @property + def mirostat(self): + """ + Enable Mirostat sampling for controlling perplexity. + (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + """ + return self._mirostat + + @property + def mirostat_eta(self): + """ + Influences how quickly the algorithm responds to feedback from the generated text. + A lower learning rate will result in slower adjustments, + while a higher learning rate will make the algorithm more responsive. + (Default: 0.1) + """ + return self._mirostat_eta + + @property + def mirostat_tau(self): + """ + Controls the balance between coherence and diversity of the output. + A lower value will result in more focused and coherent text. + (Default: 5.0) + """ + return self._mirostat_tau + + @property + def num_ctx(self): + """Sets the size of the context window used to generate the next token. (Default: 2048)""" + return self._num_ctx + + @property + def repeat_last_n(self): + """ + Sets how far back for the model to look back to prevent repetition. + (Default: 64, 0 = disabled, -1 = num_ctx) + """ + return self._repeat_last_n + + @property + def repeat_penalty(self): + """ + Sets how strongly to penalize repetitions. + A higher value (e.g., 1.5) will penalize repetitions more strongly, + while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) + """ + return self._repeat_penalty + + @property + def seed(self): + """ + Sets the random number seed to use for generation. + Setting this to a specific number will make the model generate the same text for the same prompt. + (Default: 0) + """ + return self._seed + + @property + def tfs_z(self): + """ + Tail free sampling is used to reduce the impact of less probable tokens from the output. + A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. + (default: 1) + """ + return self._tfs_z + + @property + def num_predict(self): + """ + Maximum number of tokens to predict when generating text. + (Default: 128, -1 = infinite generation, -2 = fill context) + """ + return self._num_predict + + @property + def top_k(self): + """ + Reduces the probability of generating nonsense. + A higher value (e.g. 100) will give more diverse answers, + while a lower value (e.g. 10) will be more conservative. + (Default: 40) + """ + return self._top_k + + @property + def min_p(self): + """Alternative to the top_p, and aims to ensure a balance of quality and variety. + The parameter p represents the minimum probability for a token to be considered, + relative to the probability of the most likely token. + For example, with p=0.05 and the most likely token having a probability of 0.9, + logits with a value less than 0.045 are filtered out. 
+ (Default: 0.0) + """ + return self._min_p + + @property + def suffix(self): + """See https://github.com/ollama/ollama/blob/main/docs/modelfile.md#template""" + return self._suffix + + @property + def system(self): + """The SYSTEM instruction specifies the system message to be used in the template, if applicable.""" + return self._system + + @property + def template(self): + """See https://github.com/ollama/ollama/blob/main/docs/modelfile.md#template""" + return self._template + + @property + def raw(self): + """ + If true no formatting will be applied to the prompt. + You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API + """ + return self._raw + + @property + def keep_alive(self): + """ + Controls how long the model will stay loaded into memory following the request. + (default: 5m) + """ + return self._keep_alive + + @property + def stream(self): + """ + If false the response will be returned as a single response object, rather than a stream of objects. + (default: True) + """ + return self._stream + + @property + def options(self) -> dict: + """Additional model parameters listed in the documentation for the Modelfile such as temperature""" + return non_none_value_key( + { + "n": self.n, + "temperature": self.temperature, + "top_p": self.top_p, + "format": self.format, + "stop": self.stop, + "mirostat": self.mirostat, + "mirostat_eta": self.mirostat_eta, + "mirostat_tau": self.mirostat_tau, + "num_ctx": self.num_ctx, + "repeat_last_n": self.repeat_last_n, + "repeat_penalty": self.repeat_penalty, + "frequency_penalty": self.frequency_penalty, + "seed": self.seed, + "tfs_z": self.tfs_z, + "num_predict": self.num_predict, + "top_k": self.top_k, + "min_p": self.min_p, + "suffix": self.suffix, + "system": self.system, + "template": self.template, + "raw": self.raw, + "keep_alive": self.keep_alive, + } + ) + + def lookup(self, name: str, default_value: Any = None) -> Any: + """Lookup method definition.""" + return self._raw_config.get(name, default_value) + + def get_completion_cache_args(self): + """Get the cache arguments for a completion(generate) LLM.""" + return non_none_value_key( + { + "model": self.model, + "suffix": self.suffix, + "format": self.format, + "system": self.system, + "template": self.template, + # "context": self.context, + "options": self.options, + "stream": self.stream, + "raw": self.raw, + "keep_alive": self.keep_alive, + } + ) + + def get_chat_cache_args(self) -> dict: + """Get the cache arguments for a chat LLM.""" + return non_none_value_key( + { + "model": self.model, + "format": self.format, + "options": self.options, + "stream": self.stream, + "keep_alive": self.keep_alive, + } + ) + + def __str__(self) -> str: + """Str method definition.""" + return json.dumps(self.raw_config, indent=4) + + def __repr__(self) -> str: + """Repr method definition.""" + return f"OpenAIConfiguration({self._raw_config})" + + def __eq__(self, other: object) -> bool: + """Eq method definition.""" + if not isinstance(other, OllamaConfiguration): + return False + return self._raw_config == other._raw_config + + def __hash__(self) -> int: + """Hash method definition.""" + return hash(tuple(sorted(self._raw_config.items()))) diff --git a/graphrag/llm/ollama/ollama_embeddings_llm.py b/graphrag/llm/ollama/ollama_embeddings_llm.py new file mode 100644 index 0000000000..a223e356b5 --- /dev/null +++ b/graphrag/llm/ollama/ollama_embeddings_llm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. 
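
Since `options` and the two cache-args helpers drive every request the wrappers send, a small illustrative sketch of what they produce; the values are arbitrary, and None entries are stripped by `non_none_value_key`:

```python
from graphrag.llm.ollama import OllamaConfiguration

config = OllamaConfiguration({
    "api_key": "ollama",
    "model": "llama3.1",
    "api_base": "http://localhost:11434",
    "temperature": 0.0,
    "top_k": 40,
    "seed": 42,
    "response_format": "json",   # surfaces as Ollama's "format" argument
})

print(config.options)
# e.g. {'temperature': 0.0, 'format': 'json', 'seed': 42, 'top_k': 40}
print(config.get_chat_cache_args())
# e.g. {'model': 'llama3.1', 'format': 'json', 'options': {...}}
```
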
+# Licensed under the MIT License + +"""The EmbeddingsLLM class.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + EmbeddingInput, + EmbeddingOutput, + LLMInput, +) + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + + +class OllamaEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): + """A text-embedding generator LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] + ) -> EmbeddingOutput | None: + args = { + "model": self.configuration.model, + **(kwargs.get("model_parameters") or {}), + } + embedding = await self.client.embed( + input=input, + **args, + ) + return embedding["embeddings"] diff --git a/graphrag/llm/ollama/types.py b/graphrag/llm/ollama/types.py new file mode 100644 index 0000000000..b3719fce95 --- /dev/null +++ b/graphrag/llm/ollama/types.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A base class for OpenAI-based LLMs.""" + +from ollama import AsyncClient + +OllamaClientType = AsyncClient diff --git a/graphrag/llm/openai/factories.py b/graphrag/llm/openai/factories.py index e595e2e55b..9cc68810da 100644 --- a/graphrag/llm/openai/factories.py +++ b/graphrag/llm/openai/factories.py @@ -16,6 +16,12 @@ LLMInvocationFn, OnCacheActionFn, ) +from graphrag.llm.utils import ( + RATE_LIMIT_ERRORS, + RETRYABLE_ERRORS, + get_sleep_time_from_error, + get_token_counter, +) from .json_parsing_llm import JsonParsingLLM from .openai_chat_llm import OpenAIChatLLM @@ -25,13 +31,6 @@ from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM from .openai_token_replacing_llm import OpenAITokenReplacingLLM from .types import OpenAIClientTypes -from .utils import ( - RATE_LIMIT_ERRORS, - RETRYABLE_ERRORS, - get_completion_cache_args, - get_sleep_time_from_error, - get_token_counter, -) def create_openai_chat_llm( @@ -133,7 +132,7 @@ def _cached( on_cache_hit: OnCacheActionFn | None, on_cache_miss: OnCacheActionFn | None, ): - cache_args = get_completion_cache_args(config) + cache_args = config.get_completion_cache_args() result = CachingLLM(delegate, cache_args, operation, cache) result.on_cache_hit(on_cache_hit) result.on_cache_miss(on_cache_miss) diff --git a/graphrag/llm/openai/json_parsing_llm.py b/graphrag/llm/openai/json_parsing_llm.py index 009c1da42e..588a0480c1 100644 --- a/graphrag/llm/openai/json_parsing_llm.py +++ b/graphrag/llm/openai/json_parsing_llm.py @@ -14,7 +14,7 @@ LLMOutput, ) -from .utils import try_parse_json_object +from graphrag.llm.utils import try_parse_json_object class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]): diff --git a/graphrag/llm/openai/openai_chat_llm.py b/graphrag/llm/openai/openai_chat_llm.py index bd821ac661..1953f838b5 100644 --- a/graphrag/llm/openai/openai_chat_llm.py +++ b/graphrag/llm/openai/openai_chat_llm.py @@ -14,14 +14,14 @@ LLMInput, LLMOutput, ) +from graphrag.llm.utils import ( + get_completion_llm_args, + try_parse_json_object, +) from ._prompts import JSON_CHECK_PROMPT from .openai_configuration import OpenAIConfiguration from .types import OpenAIClientTypes -from .utils import ( - get_completion_llm_args, - try_parse_json_object, -) log = logging.getLogger(__name__) diff --git 
a/graphrag/llm/openai/openai_completion_llm.py b/graphrag/llm/openai/openai_completion_llm.py index 74511c02a2..3c31cc9aeb 100644 --- a/graphrag/llm/openai/openai_completion_llm.py +++ b/graphrag/llm/openai/openai_completion_llm.py @@ -13,10 +13,10 @@ CompletionOutput, LLMInput, ) +from graphrag.llm.utils import get_completion_llm_args from .openai_configuration import OpenAIConfiguration from .types import OpenAIClientTypes -from .utils import get_completion_llm_args log = logging.getLogger(__name__) diff --git a/graphrag/llm/openai/openai_configuration.py b/graphrag/llm/openai/openai_configuration.py index cbcc54093d..3309a0c7ad 100644 --- a/graphrag/llm/openai/openai_configuration.py +++ b/graphrag/llm/openai/openai_configuration.py @@ -8,13 +8,7 @@ from typing import Any, cast from graphrag.llm.types import LLMConfig - - -def _non_blank(value: str | None) -> str | None: - if value is None: - return None - stripped = value.strip() - return None if stripped == "" else value +from graphrag.llm.utils import non_blank class OpenAIConfiguration(Hashable, LLMConfig): @@ -141,34 +135,34 @@ def model(self) -> str: @property def deployment_name(self) -> str | None: """Deployment name property definition.""" - return _non_blank(self._deployment_name) + return non_blank(self._deployment_name) @property def api_base(self) -> str | None: """API base property definition.""" - result = _non_blank(self._api_base) + result = non_blank(self._api_base) # Remove trailing slash return result[:-1] if result and result.endswith("/") else result @property def api_version(self) -> str | None: """API version property definition.""" - return _non_blank(self._api_version) + return non_blank(self._api_version) @property def audience(self) -> str | None: """API version property definition.""" - return _non_blank(self._audience) + return non_blank(self._audience) @property def organization(self) -> str | None: """Organization property definition.""" - return _non_blank(self._organization) + return non_blank(self._organization) @property def proxy(self) -> str | None: """Proxy property definition.""" - return _non_blank(self._proxy) + return non_blank(self._proxy) @property def n(self) -> int | None: @@ -203,7 +197,7 @@ def max_tokens(self) -> int | None: @property def response_format(self) -> str | None: """Response format property definition.""" - return _non_blank(self._response_format) + return non_blank(self._response_format) @property def logit_bias(self) -> dict[str, float] | None: @@ -253,7 +247,7 @@ def concurrent_requests(self) -> int | None: @property def encoding_model(self) -> str | None: """Encoding model property definition.""" - return _non_blank(self._encoding_model) + return non_blank(self._encoding_model) @property def sleep_on_rate_limit_recommendation(self) -> bool | None: @@ -269,6 +263,18 @@ def lookup(self, name: str, default_value: Any = None) -> Any: """Lookup method definition.""" return self._raw_config.get(name, default_value) + def get_completion_cache_args(self) -> dict: + """Get the cache arguments for a completion LLM.""" + return { + "model": self.model, + "temperature": self.temperature, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + "n": self.n, + } + def __str__(self) -> str: """Str method definition.""" return json.dumps(self.raw_config, indent=4) diff --git a/graphrag/llm/openai/openai_token_replacing_llm.py b/graphrag/llm/openai/openai_token_replacing_llm.py index 
7385b84059..0b90835b22 100644 --- a/graphrag/llm/openai/openai_token_replacing_llm.py +++ b/graphrag/llm/openai/openai_token_replacing_llm.py @@ -14,7 +14,7 @@ LLMOutput, ) -from .utils import perform_variable_replacements +from graphrag.llm.utils import perform_variable_replacements class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]): diff --git a/graphrag/llm/types/llm_config.py b/graphrag/llm/types/llm_config.py index cd7ec255b2..c3e4db3a42 100644 --- a/graphrag/llm/types/llm_config.py +++ b/graphrag/llm/types/llm_config.py @@ -33,3 +33,7 @@ def tokens_per_minute(self) -> int | None: def requests_per_minute(self) -> int | None: """Get the number of requests per minute.""" ... + + def get_completion_cache_args(self) -> dict: + """Get the cache arguments for a completion LLM.""" + ... diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/utils.py similarity index 84% rename from graphrag/llm/openai/utils.py rename to graphrag/llm/utils.py index 64b7118d9b..ff3df40a14 100644 --- a/graphrag/llm/openai/utils.py +++ b/graphrag/llm/utils.py @@ -17,7 +17,7 @@ RateLimitError, ) -from .openai_configuration import OpenAIConfiguration +from .types import LLMConfig DEFAULT_ENCODING = "cl100k_base" @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) -def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]: +def get_token_counter(config: LLMConfig) -> Callable[[str], int]: """Get a function that counts the number of tokens in a string.""" model = config.encoding_model or "cl100k_base" enc = _encoders.get(model) @@ -66,25 +66,12 @@ def replace_all(input: str) -> str: return result -def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict: - """Get the cache arguments for a completion LLM.""" - return { - "model": configuration.model, - "temperature": configuration.temperature, - "frequency_penalty": configuration.frequency_penalty, - "presence_penalty": configuration.presence_penalty, - "top_p": configuration.top_p, - "max_tokens": configuration.max_tokens, - "n": configuration.n, - } - - def get_completion_llm_args( - parameters: dict | None, configuration: OpenAIConfiguration + parameters: dict | None, configuration: LLMConfig ) -> dict: """Get the arguments for a completion LLM.""" return { - **get_completion_cache_args(configuration), + **configuration.get_completion_cache_args(), **(parameters or {}), } @@ -158,3 +145,17 @@ def get_sleep_time_from_error(e: Any) -> float: _please_retry_after = "Please retry after " + + +def non_blank(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + return None if stripped == "" else value + + +def non_none_value_key(data: dict | None) -> dict: + """Remove key from dict where value is None""" + if data is None: + return {} + return {k: v for k, v in data.items() if v is not None} diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 5945ab9e98..22d343293e 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -15,7 +15,7 @@ import tiktoken from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback -from graphrag.llm.openai.utils import try_parse_json_object +from graphrag.llm.utils import try_parse_json_object from graphrag.query.context_builder.builders import GlobalContextBuilder from graphrag.query.context_builder.conversation_history import ( ConversationHistory, diff --git 
a/pyproject.toml b/pyproject.toml index 1056f77ae4..15f7b57cea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ azure-identity = "^1.17.1" json-repair = "^0.30.0" future = "^1.0.0" # Needed until graspologic fixes their dependency +ollama = "^0.3.3" [tool.poetry.group.dev.dependencies] coverage = "^7.6.0" From 5c9dcdb201b4665948843d8a4523c4d2ded73aa5 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Thu, 24 Oct 2024 12:35:19 +0800 Subject: [PATCH 2/7] remove useless print code --- graphrag/index/llm/load_llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index f3d58aed2a..1b6682b364 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -52,7 +52,6 @@ def load_llm( ) -> CompletionLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) - print(llm_type.value) if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type {llm_type} does not support chat" From a98ae6a48d49d7d75c6061349f41688650115cd9 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Sat, 26 Oct 2024 16:21:15 +0800 Subject: [PATCH 3/7] resolve generate community report error --- graphrag/llm/ollama/ollama_chat_llm.py | 40 ++++++++++++++++++--- graphrag/llm/ollama/ollama_configuration.py | 2 +- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/graphrag/llm/ollama/ollama_chat_llm.py b/graphrag/llm/ollama/ollama_chat_llm.py index a9ba70596b..5f61b97d76 100644 --- a/graphrag/llm/ollama/ollama_chat_llm.py +++ b/graphrag/llm/ollama/ollama_chat_llm.py @@ -14,6 +14,7 @@ LLMInput, LLMOutput, ) +from graphrag.llm.utils import try_parse_json_object from .ollama_configuration import OllamaConfiguration from .types import OllamaClientType @@ -39,7 +40,6 @@ async def _execute_llm( ) -> CompletionOutput | None: args = { **self.configuration.get_chat_cache_args(), - **(kwargs.get("model_parameters") or {}), } history = kwargs.get("history") or [] messages = [ @@ -52,9 +52,39 @@ async def _execute_llm( return completion["message"]["content"] async def _invoke_json( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], ) -> LLMOutput[CompletionOutput]: """Generate JSON output.""" - pass + name = kwargs.get("name") or "unknown" + is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True) + + async def generate( + attempt: int | None = None, + ) -> LLMOutput[CompletionOutput]: + call_name = name if attempt is None else f"{name}@{attempt}" + result = await self._invoke(input, **{**kwargs, "name": call_name}) + print("output:\n", result) + output, json_output = try_parse_json_object(result.output or "") + + return LLMOutput[CompletionOutput]( + output=output, + json=json_output, + history=result.history, + ) + + def is_valid(x: dict | None) -> bool: + return x is not None and is_response_valid(x) + + result = await generate() + retry = 0 + while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES: + result = await generate(retry) + retry += 1 + + if is_valid(result.json): + return result + + error_msg = f"{FAILED_TO_CREATE_JSON_ERROR} - Faulty JSON: {result.json!s}" + raise RuntimeError(error_msg) diff --git a/graphrag/llm/ollama/ollama_configuration.py b/graphrag/llm/ollama/ollama_configuration.py index 468b3c5b33..7a0fd9db62 100644 --- a/graphrag/llm/ollama/ollama_configuration.py +++ b/graphrag/llm/ollama/ollama_configuration.py @@ 
-116,7 +116,7 @@ def lookup_bool(key: str) -> bool | None: self._mirostat = lookup_int("mirostat") self._mirostat_eta = lookup_float("mirostat_eta") self._mirostat_tau = lookup_float("mirostat_tau") - self._num_ctx = lookup_int("num_ctx") + self._num_ctx = lookup_int("max_tokens") self._repeat_last_n = lookup_int("repeat_last_n") self._repeat_penalty = lookup_float("repeat_penalty") self._frequency_penalty = lookup_float("frequency_penalty") From 82d3a073653a9bbf2e26bc7ffa444945096e69d7 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Thu, 24 Oct 2024 11:41:07 +0800 Subject: [PATCH 4/7] ollama support. --- graphrag/config/enums.py | 3 + graphrag/index/llm/load_llm.py | 110 +++- graphrag/llm/__init__.py | 21 + graphrag/llm/ollama/__init__.py | 29 ++ graphrag/llm/ollama/create_ollama_client.py | 30 ++ graphrag/llm/ollama/factories.py | 139 +++++ graphrag/llm/ollama/json_parsing_llm.py | 38 ++ graphrag/llm/ollama/ollama_chat_llm.py | 60 +++ graphrag/llm/ollama/ollama_completion_llm.py | 44 ++ graphrag/llm/ollama/ollama_configuration.py | 493 ++++++++++++++++++ graphrag/llm/ollama/ollama_embeddings_llm.py | 40 ++ graphrag/llm/ollama/types.py | 8 + graphrag/llm/openai/factories.py | 15 +- graphrag/llm/openai/json_parsing_llm.py | 2 +- graphrag/llm/openai/openai_chat_llm.py | 8 +- graphrag/llm/openai/openai_completion_llm.py | 2 +- graphrag/llm/openai/openai_configuration.py | 36 +- .../llm/openai/openai_token_replacing_llm.py | 2 +- graphrag/llm/types/llm_config.py | 4 + graphrag/llm/{openai => }/utils.py | 35 +- .../structured_search/global_search/search.py | 2 +- pyproject.toml | 1 + 22 files changed, 1071 insertions(+), 51 deletions(-) create mode 100644 graphrag/llm/ollama/__init__.py create mode 100644 graphrag/llm/ollama/create_ollama_client.py create mode 100644 graphrag/llm/ollama/factories.py create mode 100644 graphrag/llm/ollama/json_parsing_llm.py create mode 100644 graphrag/llm/ollama/ollama_chat_llm.py create mode 100644 graphrag/llm/ollama/ollama_completion_llm.py create mode 100644 graphrag/llm/ollama/ollama_configuration.py create mode 100644 graphrag/llm/ollama/ollama_embeddings_llm.py create mode 100644 graphrag/llm/ollama/types.py rename graphrag/llm/{openai => }/utils.py (84%) diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 8741cf74ae..4410b272c4 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -98,14 +98,17 @@ class LLMType(str, Enum): # Embeddings OpenAIEmbedding = "openai_embedding" AzureOpenAIEmbedding = "azure_openai_embedding" + OllamaEmbedding = "ollama_embedding" # Raw Completion OpenAI = "openai" AzureOpenAI = "azure_openai" + Ollama = "ollama" # Chat Completion OpenAIChat = "openai_chat" AzureOpenAIChat = "azure_openai_chat" + OllamaChat = "ollama_chat" # Debug StaticResponse = "static_response" diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index a7eda31a4e..f3d58aed2a 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -22,6 +22,12 @@ create_openai_completion_llm, create_openai_embedding_llm, create_tpm_rpm_limiters, + OllamaConfiguration, + create_ollama_client, + create_ollama_chat_llm, + create_ollama_embedding_llm, + create_ollama_completion_llm, + LLMConfig, ) if TYPE_CHECKING: @@ -46,7 +52,7 @@ def load_llm( ) -> CompletionLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) - + print(llm_type.value) if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type 
{llm_type} does not support chat" @@ -182,6 +188,50 @@ def _load_azure_openai_embeddings_llm( return _load_openai_embeddings_llm(on_error, cache, config, True) +def _load_ollama_completion_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + return _create_ollama_completion_llm( + OllamaConfiguration({ + **_get_base_config(config), + }), + on_error, + cache, + ) + + +def _load_ollama_chat_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + return _create_ollama_chat_llm( + OllamaConfiguration({ + # Set default values + **_get_base_config(config), + }), + on_error, + cache, + ) + + +def _load_ollama_embeddings_llm( + on_error: ErrorHandlerFn, + cache: LLMCache, + config: dict[str, Any], +): + # TODO: Inject Cache + return _create_ollama_embeddings_llm( + OllamaConfiguration({ + **_get_base_config(config), + }), + on_error, + cache, + ) + + def _get_base_config(config: dict[str, Any]) -> dict[str, Any]: api_key = config.get("api_key") @@ -218,6 +268,10 @@ def _load_static_response( "load": _load_azure_openai_completion_llm, "chat": False, }, + LLMType.Ollama: { + "load": _load_ollama_completion_llm, + "chat": False, + }, LLMType.OpenAIChat: { "load": _load_openai_chat_llm, "chat": True, @@ -226,6 +280,10 @@ def _load_static_response( "load": _load_azure_openai_chat_llm, "chat": True, }, + LLMType.OllamaChat: { + "load": _load_ollama_chat_llm, + "chat": True, + }, LLMType.OpenAIEmbedding: { "load": _load_openai_embeddings_llm, "chat": False, @@ -234,6 +292,10 @@ def _load_static_response( "load": _load_azure_openai_embeddings_llm, "chat": False, }, + LLMType.OllamaEmbedding: { + "load": _load_ollama_embeddings_llm, + "chat": False, + }, LLMType.StaticResponse: { "load": _load_static_response, "chat": False, @@ -286,7 +348,49 @@ def _create_openai_embeddings_llm( ) -def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter: +def _create_ollama_chat_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> CompletionLLM: + """Create an Ollama chat llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_chat_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_ollama_completion_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> CompletionLLM: + """Create an Ollama completion llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_completion_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_ollama_embeddings_llm( + configuration: OllamaConfiguration, + on_error: ErrorHandlerFn, + cache: LLMCache, +) -> EmbeddingLLM: + """Create an Ollama embeddings llm.""" + client = create_ollama_client(configuration=configuration) + limiter = _create_limiter(configuration) + semaphore = _create_semaphore(configuration) + return create_ollama_embedding_llm( + client, configuration, cache, limiter, semaphore, on_error=on_error + ) + + +def _create_limiter(configuration: LLMConfig) -> LLMLimiter: limit_name = configuration.model or configuration.deployment_name or "default" if limit_name not in _rate_limiters: tpm = configuration.tokens_per_minute @@ -296,7 +400,7 @@ def _create_limiter(configuration: OpenAIConfiguration) -> 
LLMLimiter: return _rate_limiters[limit_name] -def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None: +def _create_semaphore(configuration: LLMConfig) -> asyncio.Semaphore | None: limit_name = configuration.model or configuration.deployment_name or "default" concurrency = configuration.concurrent_requests diff --git a/graphrag/llm/__init__.py b/graphrag/llm/__init__.py index 609be951b2..508600ee48 100644 --- a/graphrag/llm/__init__.py +++ b/graphrag/llm/__init__.py @@ -42,6 +42,17 @@ LLMOutput, OnCacheActionFn, ) +from .ollama import ( + OllamaChatLLM, + OllamaClientType, + OllamaCompletionLLM, + OllamaConfiguration, + OllamaEmbeddingsLLM, + create_ollama_chat_llm, + create_ollama_client, + create_ollama_completion_llm, + create_ollama_embedding_llm, +) __all__ = [ # LLM Types @@ -79,6 +90,12 @@ "OpenAIConfiguration", "OpenAIEmbeddingsLLM", "RateLimitingLLM", + # Ollama + "OllamaChatLLM", + "OllamaClientType", + "OllamaCompletionLLM", + "OllamaConfiguration", + "OllamaEmbeddingsLLM", # Errors "RetriesExhaustedError", "TpmRpmLLMLimiter", @@ -86,6 +103,10 @@ "create_openai_client", "create_openai_completion_llm", "create_openai_embedding_llm", + "create_ollama_chat_llm", + "create_ollama_client", + "create_ollama_completion_llm", + "create_ollama_embedding_llm", # Limiters "create_tpm_rpm_limiters", ] diff --git a/graphrag/llm/ollama/__init__.py b/graphrag/llm/ollama/__init__.py new file mode 100644 index 0000000000..adf27d50a5 --- /dev/null +++ b/graphrag/llm/ollama/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Ollama LLM implementations.""" + +from .create_ollama_client import create_ollama_client +from .factories import ( + create_ollama_chat_llm, + create_ollama_completion_llm, + create_ollama_embedding_llm, +) +from .ollama_chat_llm import OllamaChatLLM +from .ollama_completion_llm import OllamaCompletionLLM +from .ollama_configuration import OllamaConfiguration +from .ollama_embeddings_llm import OllamaEmbeddingsLLM +from .types import OllamaClientType + + +__all__ = [ + "OllamaChatLLM", + "OllamaClientType", + "OllamaCompletionLLM", + "OllamaConfiguration", + "OllamaEmbeddingsLLM", + "create_ollama_chat_llm", + "create_ollama_client", + "create_ollama_completion_llm", + "create_ollama_embedding_llm", +] diff --git a/graphrag/llm/ollama/create_ollama_client.py b/graphrag/llm/ollama/create_ollama_client.py new file mode 100644 index 0000000000..87f5902c44 --- /dev/null +++ b/graphrag/llm/ollama/create_ollama_client.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+"""Create an Ollama client instance."""
+
+import logging
+from functools import cache
+
+from ollama import AsyncClient
+
+from .ollama_configuration import OllamaConfiguration
+from .types import OllamaClientType
+
+log = logging.getLogger(__name__)
+
+API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client"
+
+
+@cache
+def create_ollama_client(
+    configuration: OllamaConfiguration
+) -> OllamaClientType:
+    """Create a new Ollama client instance."""
+
+    log.info("Creating OpenAI client base_url=%s", configuration.api_base)
+    return AsyncClient(
+        host=configuration.api_base,
+        # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here
+        timeout=configuration.request_timeout or 180.0,
+    )
diff --git a/graphrag/llm/ollama/factories.py b/graphrag/llm/ollama/factories.py
new file mode 100644
index 0000000000..e203cf48a1
--- /dev/null
+++ b/graphrag/llm/ollama/factories.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Factory functions for creating Ollama LLMs."""
+
+import asyncio
+
+from graphrag.llm.base import CachingLLM, RateLimitingLLM
+from graphrag.llm.limiting import LLMLimiter
+from graphrag.llm.types import (
+    LLM,
+    CompletionLLM,
+    EmbeddingLLM,
+    ErrorHandlerFn,
+    LLMCache,
+    LLMInvocationFn,
+    OnCacheActionFn,
+)
+from graphrag.llm.utils import (
+    RATE_LIMIT_ERRORS,
+    RETRYABLE_ERRORS,
+    get_sleep_time_from_error,
+    get_token_counter,
+)
+from graphrag.llm.openai.openai_history_tracking_llm import OpenAIHistoryTrackingLLM
+from graphrag.llm.openai.openai_token_replacing_llm import OpenAITokenReplacingLLM
+
+from .json_parsing_llm import JsonParsingLLM
+from .ollama_chat_llm import OllamaChatLLM
+from .ollama_completion_llm import OllamaCompletionLLM
+from .ollama_configuration import OllamaConfiguration
+from .ollama_embeddings_llm import OllamaEmbeddingsLLM
+from .types import OllamaClientType
+
+
+def create_ollama_chat_llm(
+    client: OllamaClientType,
+    config: OllamaConfiguration,
+    cache: LLMCache | None = None,
+    limiter: LLMLimiter | None = None,
+    semaphore: asyncio.Semaphore | None = None,
+    on_invoke: LLMInvocationFn | None = None,
+    on_error: ErrorHandlerFn | None = None,
+    on_cache_hit: OnCacheActionFn | None = None,
+    on_cache_miss: OnCacheActionFn | None = None,
+) -> CompletionLLM:
+    """Create an Ollama chat LLM."""
+    operation = "chat"
+    result = OllamaChatLLM(client, config)
+    result.on_error(on_error)
+    if limiter is not None or semaphore is not None:
+        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
+    if cache is not None:
+        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
+    result = OpenAIHistoryTrackingLLM(result)
+    result = OpenAITokenReplacingLLM(result)
+    return JsonParsingLLM(result)
+
+
+def create_ollama_completion_llm(
+    client: OllamaClientType,
+    config: OllamaConfiguration,
+    cache: LLMCache | None = None,
+    limiter: LLMLimiter | None = None,
+    semaphore: asyncio.Semaphore | None = None,
+    on_invoke: LLMInvocationFn | None = None,
+    on_error: ErrorHandlerFn | None = None,
+    on_cache_hit: OnCacheActionFn | None = None,
+    on_cache_miss: OnCacheActionFn | None = None,
+) -> CompletionLLM:
+    """Create an Ollama completion LLM."""
+    operation = "completion"
+    result = OllamaCompletionLLM(client, config)
+    result.on_error(on_error)
+    if limiter is not None or semaphore is not None:
+        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
+    if cache is not None:
+        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
+    return OpenAITokenReplacingLLM(result)
+
+
+def create_ollama_embedding_llm(
+    client: OllamaClientType,
+    config: OllamaConfiguration,
+    cache: LLMCache | None = None,
+    limiter: LLMLimiter | None = None,
+    semaphore: asyncio.Semaphore | None = None,
+    on_invoke: LLMInvocationFn | None = None,
+    on_error: ErrorHandlerFn | None = None,
+    on_cache_hit: OnCacheActionFn | None = None,
+    on_cache_miss: OnCacheActionFn | None = None,
+) -> EmbeddingLLM:
+    """Create an Ollama embeddings LLM."""
+    operation = "embedding"
+    result = OllamaEmbeddingsLLM(client, config)
+    result.on_error(on_error)
+    if limiter is not None or semaphore is not None:
+        result = _rate_limited(result, config, operation, limiter, semaphore, on_invoke)
+    if cache is not None:
+        result = _cached(result, config, operation, cache, on_cache_hit, on_cache_miss)
+    return result
+
+
+def _rate_limited(
+    delegate: LLM,
+    config: OllamaConfiguration,
+    operation: str,
+    limiter: LLMLimiter | None,
+    semaphore: asyncio.Semaphore | None,
+    on_invoke: LLMInvocationFn | None,
+):
+    result = RateLimitingLLM(
+        delegate,
+        config,
+        operation,
+        RETRYABLE_ERRORS,
+        RATE_LIMIT_ERRORS,
+        limiter,
+        semaphore,
+        get_token_counter(config),
+        get_sleep_time_from_error,
+    )
+    result.on_invoke(on_invoke)
+    return result
+
+
+def _cached(
+    delegate: LLM,
+    config: OllamaConfiguration,
+    operation: str,
+    cache: LLMCache,
+    on_cache_hit: OnCacheActionFn | None,
+    on_cache_miss: OnCacheActionFn | None,
+):
+    cache_args = config.get_completion_cache_args()
+    result = CachingLLM(delegate, cache_args, operation, cache)
+    result.on_cache_hit(on_cache_hit)
+    result.on_cache_miss(on_cache_miss)
+    return result
diff --git a/graphrag/llm/ollama/json_parsing_llm.py b/graphrag/llm/ollama/json_parsing_llm.py
new file mode 100644
index 0000000000..588a0480c1
--- /dev/null
+++ b/graphrag/llm/ollama/json_parsing_llm.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""An LLM that unpacks cached JSON responses."""
+
+from typing_extensions import Unpack
+
+from graphrag.llm.types import (
+    LLM,
+    CompletionInput,
+    CompletionLLM,
+    CompletionOutput,
+    LLMInput,
+    LLMOutput,
+)
+
+from graphrag.llm.utils import try_parse_json_object
+
+
+class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]):
+    """An LLM that parses JSON from its delegate's text output."""
+
+    _delegate: CompletionLLM
+
+    def __init__(self, delegate: CompletionLLM):
+        self._delegate = delegate
+
+    async def __call__(
+        self,
+        input: CompletionInput,
+        **kwargs: Unpack[LLMInput],
+    ) -> LLMOutput[CompletionOutput]:
+        """Call the LLM with the input and kwargs."""
+        result = await self._delegate(input, **kwargs)
+        if kwargs.get("json") and result.json is None and result.output is not None:
+            _, parsed_json = try_parse_json_object(result.output)
+            result.json = parsed_json
+        return result
diff --git a/graphrag/llm/ollama/ollama_chat_llm.py b/graphrag/llm/ollama/ollama_chat_llm.py
new file mode 100644
index 0000000000..a9ba70596b
--- /dev/null
+++ b/graphrag/llm/ollama/ollama_chat_llm.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""The Chat-based language model.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, + LLMOutput, +) + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + +log = logging.getLogger(__name__) + +_MAX_GENERATION_RETRIES = 3 +FAILED_TO_CREATE_JSON_ERROR = "Failed to generate valid JSON output" + + +class OllamaChatLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A Chat-based LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: CompletionInput, **kwargs: Unpack[LLMInput] + ) -> CompletionOutput | None: + args = { + **self.configuration.get_chat_cache_args(), + **(kwargs.get("model_parameters") or {}), + } + history = kwargs.get("history") or [] + messages = [ + *history, + {"role": "user", "content": input}, + ] + completion = await self.client.chat( + messages=messages, **args + ) + return completion["message"]["content"] + + async def _invoke_json( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> LLMOutput[CompletionOutput]: + """Generate JSON output.""" + pass diff --git a/graphrag/llm/ollama/ollama_completion_llm.py b/graphrag/llm/ollama/ollama_completion_llm.py new file mode 100644 index 0000000000..4102418def --- /dev/null +++ b/graphrag/llm/ollama/ollama_completion_llm.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A text-completion based LLM.""" + +import logging + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + CompletionInput, + CompletionOutput, + LLMInput, +) +from graphrag.llm.utils import get_completion_llm_args + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + + +log = logging.getLogger(__name__) + + +class OllamaCompletionLLM(BaseLLM[CompletionInput, CompletionOutput]): + """A text-completion based LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], + ) -> CompletionOutput | None: + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) + completion = await self.client.generate(prompt=input, **args) + return completion["response"] diff --git a/graphrag/llm/ollama/ollama_configuration.py b/graphrag/llm/ollama/ollama_configuration.py new file mode 100644 index 0000000000..468b3c5b33 --- /dev/null +++ b/graphrag/llm/ollama/ollama_configuration.py @@ -0,0 +1,493 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Ollama Configuration class definition.""" +import json +from collections.abc import Hashable +from typing import cast, Any + +from graphrag.llm import LLMConfig +from graphrag.llm.utils import non_blank, non_none_value_key + + +class OllamaConfiguration(Hashable, LLMConfig): + """OpenAI Configuration class definition.""" + + # Core Configuration + _api_key: str + _model: str + + _api_base: str | None + _api_version: str | None + _organization: str | None + + # Operation Configuration + _n: int | None + _temperature: float | None + _top_p: float | None + _format: str | None + _stop: str | None + _mirostat: int | None + _mirostat_eta: float | None + _mirostat_tau: float | None + _num_ctx: int | None + _repeat_last_n: int | None + _repeat_penalty: float | None + _frequency_penalty: float | None + _seed: int | None + _tfs_z: float | None + _num_predict: int | None + _top_k: int | None + _min_p: float | None + _options: dict | None + _suffix: str | None + _system: str | None + _template: str | None + _raw: bool | None + _keep_alive: int | None + _stream: bool | None + + + # Retry Logic + _max_retries: int | None + _max_retry_wait: float | None + _request_timeout: float | None + + # The raw configuration object + _raw_config: dict + + # Feature Flags + _model_supports_json: bool | None + + # Custom Configuration + _tokens_per_minute: int | None + _requests_per_minute: int | None + _concurrent_requests: int | None + _encoding_model: str | None + _sleep_on_rate_limit_recommendation: bool | None + + def __init__( + self, + config: dict, + ): + """Init method definition.""" + + def lookup_required(key: str) -> str: + return cast(str, config.get(key)) + + def lookup_str(key: str) -> str | None: + return cast(str | None, config.get(key)) + + def lookup_int(key: str) -> int | None: + result = config.get(key) + if result is None: + return None + return int(cast(int, result)) + + def lookup_float(key: str) -> float | None: + result = config.get(key) + if result is None: + return None + return float(cast(float, result)) + + def lookup_dict(key: str) -> dict | None: + return cast(dict | None, config.get(key)) + + def lookup_list(key: str) -> list | None: + return cast(list | None, config.get(key)) + + def lookup_bool(key: str) -> bool | None: + value = config.get(key) + if isinstance(value, str): + return value.upper() == "TRUE" + if isinstance(value, int): + return value > 0 + return cast(bool | None, config.get(key)) + + self._api_key = lookup_required("api_key") + self._model = lookup_required("model") + self._api_base = lookup_str("api_base") + self._api_version = lookup_str("api_version") + self._organization = lookup_str("organization") + self._n = lookup_int("n") + self._temperature = lookup_float("temperature") + self._top_p = lookup_float("top_p") + self._stop = lookup_str("stop") + self._mirostat = lookup_int("mirostat") + self._mirostat_eta = lookup_float("mirostat_eta") + self._mirostat_tau = lookup_float("mirostat_tau") + self._num_ctx = lookup_int("num_ctx") + self._repeat_last_n = lookup_int("repeat_last_n") + self._repeat_penalty = lookup_float("repeat_penalty") + self._frequency_penalty = lookup_float("frequency_penalty") + self._seed = lookup_int("seed") + self._tfs_z = lookup_float("tfs_z") + self._num_predict = lookup_int("num_predict") + self._top_k = lookup_int("top_k") + self._min_p = lookup_float("min_p") + self._suffix = lookup_str("suffix") + self._system = lookup_str("system") + self._template = lookup_str("template") + self._raw = 
lookup_bool("raw") + self._keep_alive = lookup_int("keep_alive") + self._stream = lookup_bool("stream") + self._format = lookup_str("response_format") + self._max_retries = lookup_int("max_retries") + self._request_timeout = lookup_float("request_timeout") + self._model_supports_json = lookup_bool("model_supports_json") + self._tokens_per_minute = lookup_int("tokens_per_minute") + self._requests_per_minute = lookup_int("requests_per_minute") + self._concurrent_requests = lookup_int("concurrent_requests") + self._encoding_model = lookup_str("encoding_model") + self._max_retry_wait = lookup_float("max_retry_wait") + self._sleep_on_rate_limit_recommendation = lookup_bool( + "sleep_on_rate_limit_recommendation" + ) + self._raw_config = config + self._options = { + "n": self._n, + "temperature": self._temperature, + "top_p": self._top_p, + "format": self._format, + "stop": self._stop, + "mirostat": self._mirostat, + "mirostat_eta": self._mirostat_eta, + "mirostat_tau": self._mirostat_tau, + "num_ctx": self._num_ctx, + "repeat_last_n": self._repeat_last_n, + "repeat_penalty": self._repeat_penalty, + "frequency_penalty": self._frequency_penalty, + "seed": self._seed, + "tfs_z": self._tfs_z, + "num_predict": self._num_predict, + "top_k": self._top_k, + "min_p": self._min_p, + "suffix": self._suffix, + "system": self._system, + "template": self._template, + "raw": self._raw, + "keep_alive": self._keep_alive, + } + + @property + def api_key(self) -> str: + """API key property definition.""" + return self._api_key + + @property + def model(self) -> str: + """Model property definition.""" + return self._model + + @property + def api_base(self) -> str | None: + """API base property definition.""" + result = non_blank(self._api_base) + # Remove trailing slash + return result[:-1] if result and result.endswith("/") else result + + @property + def api_version(self) -> str | None: + """API version property definition.""" + return non_blank(self._api_version) + + @property + def organization(self) -> str | None: + """Organization property definition.""" + return non_blank(self._organization) + + @property + def n(self) -> int | None: + """N property definition.""" + return self._n + + @property + def temperature(self) -> float | None: + """Temperature property definition.""" + return self._temperature + + @property + def frequency_penalty(self) -> float | None: + """Frequency penalty property definition.""" + return self._frequency_penalty + + @property + def top_p(self) -> float | None: + """Top p property definition.""" + return self._top_p + + @property + def stop(self) -> str | None: + """Stop property definition.""" + return self._stop + + @property + def max_retries(self) -> int | None: + """Max retries property definition.""" + return self._max_retries + + @property + def max_retry_wait(self) -> float | None: + """Max retry wait property definition.""" + return self._max_retry_wait + + @property + def request_timeout(self) -> float | None: + """Request timeout property definition.""" + return self._request_timeout + + @property + def model_supports_json(self) -> bool | None: + """Model supports json property definition.""" + return self._model_supports_json + + @property + def tokens_per_minute(self) -> int | None: + """Tokens per minute property definition.""" + return self._tokens_per_minute + + @property + def requests_per_minute(self) -> int | None: + """Requests per minute property definition.""" + return self._requests_per_minute + + @property + def concurrent_requests(self) -> int | None: + 
"""Concurrent requests property definition.""" + return self._concurrent_requests + + @property + def encoding_model(self) -> str | None: + """Encoding model property definition.""" + return non_blank(self._encoding_model) + + @property + def sleep_on_rate_limit_recommendation(self) -> bool | None: + """Whether to sleep for seconds when recommended by 429 errors (azure-specific).""" + return self._sleep_on_rate_limit_recommendation + + @property + def raw_config(self) -> dict: + """Raw config method definition.""" + return self._raw_config + + @property + def format(self) -> str | None: + """The format to return a response in. Currently the only accepted value is json""" + return self._format + + @property + def mirostat(self): + """ + Enable Mirostat sampling for controlling perplexity. + (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) + """ + return self._mirostat + + @property + def mirostat_eta(self): + """ + Influences how quickly the algorithm responds to feedback from the generated text. + A lower learning rate will result in slower adjustments, + while a higher learning rate will make the algorithm more responsive. + (Default: 0.1) + """ + return self._mirostat_eta + + @property + def mirostat_tau(self): + """ + Controls the balance between coherence and diversity of the output. + A lower value will result in more focused and coherent text. + (Default: 5.0) + """ + return self._mirostat_tau + + @property + def num_ctx(self): + """Sets the size of the context window used to generate the next token. (Default: 2048)""" + return self._num_ctx + + @property + def repeat_last_n(self): + """ + Sets how far back for the model to look back to prevent repetition. + (Default: 64, 0 = disabled, -1 = num_ctx) + """ + return self._repeat_last_n + + @property + def repeat_penalty(self): + """ + Sets how strongly to penalize repetitions. + A higher value (e.g., 1.5) will penalize repetitions more strongly, + while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) + """ + return self._repeat_penalty + + @property + def seed(self): + """ + Sets the random number seed to use for generation. + Setting this to a specific number will make the model generate the same text for the same prompt. + (Default: 0) + """ + return self._seed + + @property + def tfs_z(self): + """ + Tail free sampling is used to reduce the impact of less probable tokens from the output. + A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. + (default: 1) + """ + return self._tfs_z + + @property + def num_predict(self): + """ + Maximum number of tokens to predict when generating text. + (Default: 128, -1 = infinite generation, -2 = fill context) + """ + return self._num_predict + + @property + def top_k(self): + """ + Reduces the probability of generating nonsense. + A higher value (e.g. 100) will give more diverse answers, + while a lower value (e.g. 10) will be more conservative. + (Default: 40) + """ + return self._top_k + + @property + def min_p(self): + """Alternative to the top_p, and aims to ensure a balance of quality and variety. + The parameter p represents the minimum probability for a token to be considered, + relative to the probability of the most likely token. + For example, with p=0.05 and the most likely token having a probability of 0.9, + logits with a value less than 0.045 are filtered out. 
+ (Default: 0.0) + """ + return self._min_p + + @property + def suffix(self): + """See https://github.com/ollama/ollama/blob/main/docs/modelfile.md#template""" + return self._suffix + + @property + def system(self): + """The SYSTEM instruction specifies the system message to be used in the template, if applicable.""" + return self._system + + @property + def template(self): + """See https://github.com/ollama/ollama/blob/main/docs/modelfile.md#template""" + return self._template + + @property + def raw(self): + """ + If true no formatting will be applied to the prompt. + You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API + """ + return self._raw + + @property + def keep_alive(self): + """ + Controls how long the model will stay loaded into memory following the request. + (default: 5m) + """ + return self._keep_alive + + @property + def stream(self): + """ + If false the response will be returned as a single response object, rather than a stream of objects. + (default: True) + """ + return self._stream + + @property + def options(self) -> dict: + """Additional model parameters listed in the documentation for the Modelfile such as temperature""" + return non_none_value_key( + { + "n": self.n, + "temperature": self.temperature, + "top_p": self.top_p, + "format": self.format, + "stop": self.stop, + "mirostat": self.mirostat, + "mirostat_eta": self.mirostat_eta, + "mirostat_tau": self.mirostat_tau, + "num_ctx": self.num_ctx, + "repeat_last_n": self.repeat_last_n, + "repeat_penalty": self.repeat_penalty, + "frequency_penalty": self.frequency_penalty, + "seed": self.seed, + "tfs_z": self.tfs_z, + "num_predict": self.num_predict, + "top_k": self.top_k, + "min_p": self.min_p, + "suffix": self.suffix, + "system": self.system, + "template": self.template, + "raw": self.raw, + "keep_alive": self.keep_alive, + } + ) + + def lookup(self, name: str, default_value: Any = None) -> Any: + """Lookup method definition.""" + return self._raw_config.get(name, default_value) + + def get_completion_cache_args(self): + """Get the cache arguments for a completion(generate) LLM.""" + return non_none_value_key( + { + "model": self.model, + "suffix": self.suffix, + "format": self.format, + "system": self.system, + "template": self.template, + # "context": self.context, + "options": self.options, + "stream": self.stream, + "raw": self.raw, + "keep_alive": self.keep_alive, + } + ) + + def get_chat_cache_args(self) -> dict: + """Get the cache arguments for a chat LLM.""" + return non_none_value_key( + { + "model": self.model, + "format": self.format, + "options": self.options, + "stream": self.stream, + "keep_alive": self.keep_alive, + } + ) + + def __str__(self) -> str: + """Str method definition.""" + return json.dumps(self.raw_config, indent=4) + + def __repr__(self) -> str: + """Repr method definition.""" + return f"OpenAIConfiguration({self._raw_config})" + + def __eq__(self, other: object) -> bool: + """Eq method definition.""" + if not isinstance(other, OllamaConfiguration): + return False + return self._raw_config == other._raw_config + + def __hash__(self) -> int: + """Hash method definition.""" + return hash(tuple(sorted(self._raw_config.items()))) diff --git a/graphrag/llm/ollama/ollama_embeddings_llm.py b/graphrag/llm/ollama/ollama_embeddings_llm.py new file mode 100644 index 0000000000..a223e356b5 --- /dev/null +++ b/graphrag/llm/ollama/ollama_embeddings_llm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The EmbeddingsLLM class.""" + +from typing_extensions import Unpack + +from graphrag.llm.base import BaseLLM +from graphrag.llm.types import ( + EmbeddingInput, + EmbeddingOutput, + LLMInput, +) + +from .ollama_configuration import OllamaConfiguration +from .types import OllamaClientType + + +class OllamaEmbeddingsLLM(BaseLLM[EmbeddingInput, EmbeddingOutput]): + """A text-embedding generator LLM.""" + + _client: OllamaClientType + _configuration: OllamaConfiguration + + def __init__(self, client: OllamaClientType, configuration: OllamaConfiguration): + self.client = client + self.configuration = configuration + + async def _execute_llm( + self, input: EmbeddingInput, **kwargs: Unpack[LLMInput] + ) -> EmbeddingOutput | None: + args = { + "model": self.configuration.model, + **(kwargs.get("model_parameters") or {}), + } + embedding = await self.client.embed( + input=input, + **args, + ) + return embedding["embeddings"] diff --git a/graphrag/llm/ollama/types.py b/graphrag/llm/ollama/types.py new file mode 100644 index 0000000000..b3719fce95 --- /dev/null +++ b/graphrag/llm/ollama/types.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A base class for OpenAI-based LLMs.""" + +from ollama import AsyncClient + +OllamaClientType = AsyncClient diff --git a/graphrag/llm/openai/factories.py b/graphrag/llm/openai/factories.py index e595e2e55b..9cc68810da 100644 --- a/graphrag/llm/openai/factories.py +++ b/graphrag/llm/openai/factories.py @@ -16,6 +16,12 @@ LLMInvocationFn, OnCacheActionFn, ) +from graphrag.llm.utils import ( + RATE_LIMIT_ERRORS, + RETRYABLE_ERRORS, + get_sleep_time_from_error, + get_token_counter, +) from .json_parsing_llm import JsonParsingLLM from .openai_chat_llm import OpenAIChatLLM @@ -25,13 +31,6 @@ from .openai_history_tracking_llm import OpenAIHistoryTrackingLLM from .openai_token_replacing_llm import OpenAITokenReplacingLLM from .types import OpenAIClientTypes -from .utils import ( - RATE_LIMIT_ERRORS, - RETRYABLE_ERRORS, - get_completion_cache_args, - get_sleep_time_from_error, - get_token_counter, -) def create_openai_chat_llm( @@ -133,7 +132,7 @@ def _cached( on_cache_hit: OnCacheActionFn | None, on_cache_miss: OnCacheActionFn | None, ): - cache_args = get_completion_cache_args(config) + cache_args = config.get_completion_cache_args() result = CachingLLM(delegate, cache_args, operation, cache) result.on_cache_hit(on_cache_hit) result.on_cache_miss(on_cache_miss) diff --git a/graphrag/llm/openai/json_parsing_llm.py b/graphrag/llm/openai/json_parsing_llm.py index 009c1da42e..588a0480c1 100644 --- a/graphrag/llm/openai/json_parsing_llm.py +++ b/graphrag/llm/openai/json_parsing_llm.py @@ -14,7 +14,7 @@ LLMOutput, ) -from .utils import try_parse_json_object +from graphrag.llm.utils import try_parse_json_object class JsonParsingLLM(LLM[CompletionInput, CompletionOutput]): diff --git a/graphrag/llm/openai/openai_chat_llm.py b/graphrag/llm/openai/openai_chat_llm.py index bd821ac661..1953f838b5 100644 --- a/graphrag/llm/openai/openai_chat_llm.py +++ b/graphrag/llm/openai/openai_chat_llm.py @@ -14,14 +14,14 @@ LLMInput, LLMOutput, ) +from graphrag.llm.utils import ( + get_completion_llm_args, + try_parse_json_object, +) from ._prompts import JSON_CHECK_PROMPT from .openai_configuration import OpenAIConfiguration from .types import OpenAIClientTypes -from .utils import ( - get_completion_llm_args, - try_parse_json_object, -) log = logging.getLogger(__name__) diff --git 
a/graphrag/llm/openai/openai_completion_llm.py b/graphrag/llm/openai/openai_completion_llm.py index 74511c02a2..3c31cc9aeb 100644 --- a/graphrag/llm/openai/openai_completion_llm.py +++ b/graphrag/llm/openai/openai_completion_llm.py @@ -13,10 +13,10 @@ CompletionOutput, LLMInput, ) +from graphrag.llm.utils import get_completion_llm_args from .openai_configuration import OpenAIConfiguration from .types import OpenAIClientTypes -from .utils import get_completion_llm_args log = logging.getLogger(__name__) diff --git a/graphrag/llm/openai/openai_configuration.py b/graphrag/llm/openai/openai_configuration.py index cbcc54093d..3309a0c7ad 100644 --- a/graphrag/llm/openai/openai_configuration.py +++ b/graphrag/llm/openai/openai_configuration.py @@ -8,13 +8,7 @@ from typing import Any, cast from graphrag.llm.types import LLMConfig - - -def _non_blank(value: str | None) -> str | None: - if value is None: - return None - stripped = value.strip() - return None if stripped == "" else value +from graphrag.llm.utils import non_blank class OpenAIConfiguration(Hashable, LLMConfig): @@ -141,34 +135,34 @@ def model(self) -> str: @property def deployment_name(self) -> str | None: """Deployment name property definition.""" - return _non_blank(self._deployment_name) + return non_blank(self._deployment_name) @property def api_base(self) -> str | None: """API base property definition.""" - result = _non_blank(self._api_base) + result = non_blank(self._api_base) # Remove trailing slash return result[:-1] if result and result.endswith("/") else result @property def api_version(self) -> str | None: """API version property definition.""" - return _non_blank(self._api_version) + return non_blank(self._api_version) @property def audience(self) -> str | None: """API version property definition.""" - return _non_blank(self._audience) + return non_blank(self._audience) @property def organization(self) -> str | None: """Organization property definition.""" - return _non_blank(self._organization) + return non_blank(self._organization) @property def proxy(self) -> str | None: """Proxy property definition.""" - return _non_blank(self._proxy) + return non_blank(self._proxy) @property def n(self) -> int | None: @@ -203,7 +197,7 @@ def max_tokens(self) -> int | None: @property def response_format(self) -> str | None: """Response format property definition.""" - return _non_blank(self._response_format) + return non_blank(self._response_format) @property def logit_bias(self) -> dict[str, float] | None: @@ -253,7 +247,7 @@ def concurrent_requests(self) -> int | None: @property def encoding_model(self) -> str | None: """Encoding model property definition.""" - return _non_blank(self._encoding_model) + return non_blank(self._encoding_model) @property def sleep_on_rate_limit_recommendation(self) -> bool | None: @@ -269,6 +263,18 @@ def lookup(self, name: str, default_value: Any = None) -> Any: """Lookup method definition.""" return self._raw_config.get(name, default_value) + def get_completion_cache_args(self) -> dict: + """Get the cache arguments for a completion LLM.""" + return { + "model": self.model, + "temperature": self.temperature, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + "n": self.n, + } + def __str__(self) -> str: """Str method definition.""" return json.dumps(self.raw_config, indent=4) diff --git a/graphrag/llm/openai/openai_token_replacing_llm.py b/graphrag/llm/openai/openai_token_replacing_llm.py index 
7385b84059..0b90835b22 100644 --- a/graphrag/llm/openai/openai_token_replacing_llm.py +++ b/graphrag/llm/openai/openai_token_replacing_llm.py @@ -14,7 +14,7 @@ LLMOutput, ) -from .utils import perform_variable_replacements +from graphrag.llm.utils import perform_variable_replacements class OpenAITokenReplacingLLM(LLM[CompletionInput, CompletionOutput]): diff --git a/graphrag/llm/types/llm_config.py b/graphrag/llm/types/llm_config.py index cd7ec255b2..c3e4db3a42 100644 --- a/graphrag/llm/types/llm_config.py +++ b/graphrag/llm/types/llm_config.py @@ -33,3 +33,7 @@ def tokens_per_minute(self) -> int | None: def requests_per_minute(self) -> int | None: """Get the number of requests per minute.""" ... + + def get_completion_cache_args(self) -> dict: + """Get the cache arguments for a completion LLM.""" + ... diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/utils.py similarity index 84% rename from graphrag/llm/openai/utils.py rename to graphrag/llm/utils.py index 64b7118d9b..ff3df40a14 100644 --- a/graphrag/llm/openai/utils.py +++ b/graphrag/llm/utils.py @@ -17,7 +17,7 @@ RateLimitError, ) -from .openai_configuration import OpenAIConfiguration +from .types import LLMConfig DEFAULT_ENCODING = "cl100k_base" @@ -33,7 +33,7 @@ log = logging.getLogger(__name__) -def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]: +def get_token_counter(config: LLMConfig) -> Callable[[str], int]: """Get a function that counts the number of tokens in a string.""" model = config.encoding_model or "cl100k_base" enc = _encoders.get(model) @@ -66,25 +66,12 @@ def replace_all(input: str) -> str: return result -def get_completion_cache_args(configuration: OpenAIConfiguration) -> dict: - """Get the cache arguments for a completion LLM.""" - return { - "model": configuration.model, - "temperature": configuration.temperature, - "frequency_penalty": configuration.frequency_penalty, - "presence_penalty": configuration.presence_penalty, - "top_p": configuration.top_p, - "max_tokens": configuration.max_tokens, - "n": configuration.n, - } - - def get_completion_llm_args( - parameters: dict | None, configuration: OpenAIConfiguration + parameters: dict | None, configuration: LLMConfig ) -> dict: """Get the arguments for a completion LLM.""" return { - **get_completion_cache_args(configuration), + **configuration.get_completion_cache_args(), **(parameters or {}), } @@ -158,3 +145,17 @@ def get_sleep_time_from_error(e: Any) -> float: _please_retry_after = "Please retry after " + + +def non_blank(value: str | None) -> str | None: + if value is None: + return None + stripped = value.strip() + return None if stripped == "" else value + + +def non_none_value_key(data: dict | None) -> dict: + """Remove key from dict where value is None""" + if data is None: + return {} + return {k: v for k, v in data.items() if v is not None} diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 5945ab9e98..22d343293e 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -15,7 +15,7 @@ import tiktoken from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback -from graphrag.llm.openai.utils import try_parse_json_object +from graphrag.llm.utils import try_parse_json_object from graphrag.query.context_builder.builders import GlobalContextBuilder from graphrag.query.context_builder.conversation_history import ( ConversationHistory, diff --git 
a/pyproject.toml b/pyproject.toml index d454078b6a..f0b43fb4e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ json-repair = "^0.30.0" future = "^1.0.0" # Needed until graspologic fixes their dependency typer = "^0.12.5" +ollama = "^0.3.3" mkdocs-typer = "^0.0.3" [tool.poetry.group.dev.dependencies] From f57f4a36f20832439827e5976113e6f6d77cc289 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Thu, 24 Oct 2024 12:35:19 +0800 Subject: [PATCH 5/7] remove useless print code --- graphrag/index/llm/load_llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphrag/index/llm/load_llm.py b/graphrag/index/llm/load_llm.py index f3d58aed2a..1b6682b364 100644 --- a/graphrag/index/llm/load_llm.py +++ b/graphrag/index/llm/load_llm.py @@ -52,7 +52,6 @@ def load_llm( ) -> CompletionLLM: """Load the LLM for the entity extraction chain.""" on_error = _create_error_handler(callbacks) - print(llm_type.value) if llm_type in loaders: if chat_only and not loaders[llm_type]["chat"]: msg = f"LLM type {llm_type} does not support chat" From 7452efc85a37967c0ccc0b85a566a0eb491231a2 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Sat, 26 Oct 2024 16:21:15 +0800 Subject: [PATCH 6/7] resolve generate community report error --- graphrag/llm/ollama/ollama_chat_llm.py | 40 ++++++++++++++++++--- graphrag/llm/ollama/ollama_configuration.py | 2 +- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/graphrag/llm/ollama/ollama_chat_llm.py b/graphrag/llm/ollama/ollama_chat_llm.py index a9ba70596b..5f61b97d76 100644 --- a/graphrag/llm/ollama/ollama_chat_llm.py +++ b/graphrag/llm/ollama/ollama_chat_llm.py @@ -14,6 +14,7 @@ LLMInput, LLMOutput, ) +from graphrag.llm.utils import try_parse_json_object from .ollama_configuration import OllamaConfiguration from .types import OllamaClientType @@ -39,7 +40,6 @@ async def _execute_llm( ) -> CompletionOutput | None: args = { **self.configuration.get_chat_cache_args(), - **(kwargs.get("model_parameters") or {}), } history = kwargs.get("history") or [] messages = [ @@ -52,9 +52,39 @@ async def _execute_llm( return completion["message"]["content"] async def _invoke_json( - self, - input: CompletionInput, - **kwargs: Unpack[LLMInput], + self, + input: CompletionInput, + **kwargs: Unpack[LLMInput], ) -> LLMOutput[CompletionOutput]: """Generate JSON output.""" - pass + name = kwargs.get("name") or "unknown" + is_response_valid = kwargs.get("is_response_valid") or (lambda _x: True) + + async def generate( + attempt: int | None = None, + ) -> LLMOutput[CompletionOutput]: + call_name = name if attempt is None else f"{name}@{attempt}" + result = await self._invoke(input, **{**kwargs, "name": call_name}) + print("output:\n", result) + output, json_output = try_parse_json_object(result.output or "") + + return LLMOutput[CompletionOutput]( + output=output, + json=json_output, + history=result.history, + ) + + def is_valid(x: dict | None) -> bool: + return x is not None and is_response_valid(x) + + result = await generate() + retry = 0 + while not is_valid(result.json) and retry < _MAX_GENERATION_RETRIES: + result = await generate(retry) + retry += 1 + + if is_valid(result.json): + return result + + error_msg = f"{FAILED_TO_CREATE_JSON_ERROR} - Faulty JSON: {result.json!s}" + raise RuntimeError(error_msg) diff --git a/graphrag/llm/ollama/ollama_configuration.py b/graphrag/llm/ollama/ollama_configuration.py index 468b3c5b33..7a0fd9db62 100644 --- a/graphrag/llm/ollama/ollama_configuration.py +++ b/graphrag/llm/ollama/ollama_configuration.py @@ 
-116,7 +116,7 @@ def lookup_bool(key: str) -> bool | None: self._mirostat = lookup_int("mirostat") self._mirostat_eta = lookup_float("mirostat_eta") self._mirostat_tau = lookup_float("mirostat_tau") - self._num_ctx = lookup_int("num_ctx") + self._num_ctx = lookup_int("max_tokens") self._repeat_last_n = lookup_int("repeat_last_n") self._repeat_penalty = lookup_float("repeat_penalty") self._frequency_penalty = lookup_float("frequency_penalty") From 84be1e5fcd02ec4d1c19e155569923630cfb2818 Mon Sep 17 00:00:00 2001 From: L1u <932044860@qq.com> Date: Sat, 26 Oct 2024 20:45:47 +0800 Subject: [PATCH 7/7] support search by ollama --- graphrag/llm/ollama/create_ollama_client.py | 10 ++- graphrag/llm/ollama/ollama_configuration.py | 23 ++++++ graphrag/llm/ollama/types.py | 4 +- graphrag/query/factories.py | 12 ++- graphrag/query/llm/ollama/__init__.py | 10 +++ graphrag/query/llm/ollama/chat_ollama.py | 88 +++++++++++++++++++++ graphrag/query/llm/ollama/embeding.py | 34 ++++++++ 7 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 graphrag/query/llm/ollama/__init__.py create mode 100644 graphrag/query/llm/ollama/chat_ollama.py create mode 100644 graphrag/query/llm/ollama/embeding.py diff --git a/graphrag/llm/ollama/create_ollama_client.py b/graphrag/llm/ollama/create_ollama_client.py index 87f5902c44..3997adb855 100644 --- a/graphrag/llm/ollama/create_ollama_client.py +++ b/graphrag/llm/ollama/create_ollama_client.py @@ -6,7 +6,7 @@ import logging from functools import cache -from ollama import AsyncClient +from ollama import AsyncClient, Client from .ollama_configuration import OllamaConfiguration from .types import OllamaClientType @@ -18,11 +18,17 @@ @cache def create_ollama_client( - configuration: OllamaConfiguration + configuration: OllamaConfiguration, + sync: bool = False, ) -> OllamaClientType: """Create a new Ollama client instance.""" log.info("Creating OpenAI client base_url=%s", configuration.api_base) + if sync: + return Client( + host=configuration.api_base, + timeout=configuration.request_timeout or 180.0, + ) return AsyncClient( host=configuration.api_base, # Timeout/Retry Configuration - Use Tenacity for Retries, so disable them here diff --git a/graphrag/llm/ollama/ollama_configuration.py b/graphrag/llm/ollama/ollama_configuration.py index 7a0fd9db62..ed613cd523 100644 --- a/graphrag/llm/ollama/ollama_configuration.py +++ b/graphrag/llm/ollama/ollama_configuration.py @@ -47,6 +47,8 @@ class OllamaConfiguration(Hashable, LLMConfig): _keep_alive: int | None _stream: bool | None + # embedding + _truncate: bool | None # Retry Logic _max_retries: int | None @@ -168,6 +170,7 @@ def lookup_bool(key: str) -> bool | None: "raw": self._raw, "keep_alive": self._keep_alive, } + self._truncate = lookup_bool("truncate") @property def api_key(self) -> str: @@ -441,6 +444,15 @@ def options(self) -> dict: } ) + @property + def truncate(self): + """ + truncates the end of each input to fit within context length. + Returns error if false and context length is exceeded. 
+ Defaults to true + """ + return self._truncate + def lookup(self, name: str, default_value: Any = None) -> Any: """Lookup method definition.""" return self._raw_config.get(name, default_value) @@ -474,6 +486,17 @@ def get_chat_cache_args(self) -> dict: } ) + def get_embed_cache_args(self) -> dict: + """Get cache arguments for a embedding LLM.""" + return non_none_value_key( + { + "model": self.model, + "options": self.options, + "keep_alive": self.keep_alive, + "truncate": self.truncate, + } + ) + def __str__(self) -> str: """Str method definition.""" return json.dumps(self.raw_config, indent=4) diff --git a/graphrag/llm/ollama/types.py b/graphrag/llm/ollama/types.py index b3719fce95..12ff1351a5 100644 --- a/graphrag/llm/ollama/types.py +++ b/graphrag/llm/ollama/types.py @@ -3,6 +3,6 @@ """A base class for OpenAI-based LLMs.""" -from ollama import AsyncClient +from ollama import AsyncClient, Client -OllamaClientType = AsyncClient +OllamaClientType = AsyncClient | Client diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 3ad520bfbc..583115b187 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -10,6 +10,7 @@ GraphRagConfig, LLMType, ) +from graphrag.llm import OllamaConfiguration from graphrag.model import ( CommunityReport, Covariate, @@ -18,9 +19,12 @@ TextUnit, ) from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey +from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.oai.embedding import OpenAIEmbedding from graphrag.query.llm.oai.typing import OpenaiApiType +from graphrag.query.llm.ollama.chat_ollama import ChatOllama +from graphrag.query.llm.ollama.embeding import OllamaEmbedding from graphrag.query.structured_search.global_search.community_context import ( GlobalCommunityContext, ) @@ -32,8 +36,10 @@ from graphrag.vector_stores import BaseVectorStore -def get_llm(config: GraphRagConfig) -> ChatOpenAI: +def get_llm(config: GraphRagConfig) -> BaseLLM: """Get the LLM client.""" + if config.llm.type in (LLMType.Ollama, LLMType.OllamaChat): + return ChatOllama(OllamaConfiguration(dict(config.llm))) is_azure_client = ( config.llm.type == LLMType.AzureOpenAIChat or config.llm.type == LLMType.AzureOpenAI @@ -67,8 +73,10 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI: ) -def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding: +def get_text_embedder(config: GraphRagConfig) -> BaseTextEmbedding: """Get the LLM client for embeddings.""" + if config.embeddings.llm.type == LLMType.OllamaEmbedding: + return OllamaEmbedding(OllamaConfiguration(dict(config.embeddings.llm))) is_azure_client = config.embeddings.llm.type == LLMType.AzureOpenAIEmbedding debug_embedding_api_key = config.embeddings.llm.api_key or "" llm_debug_info = { diff --git a/graphrag/query/llm/ollama/__init__.py b/graphrag/query/llm/ollama/__init__.py new file mode 100644 index 0000000000..1d08741ef9 --- /dev/null +++ b/graphrag/query/llm/ollama/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License
+
+from .chat_ollama import ChatOllama
+from .embeding import OllamaEmbedding
+
+__all__ = [
+    "ChatOllama",
+    "OllamaEmbedding",
+]
diff --git a/graphrag/query/llm/ollama/chat_ollama.py b/graphrag/query/llm/ollama/chat_ollama.py
new file mode 100644
index 0000000000..ff29264040
--- /dev/null
+++ b/graphrag/query/llm/ollama/chat_ollama.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Chat-based Ollama LLM implementation."""
+from typing import Any, AsyncGenerator, Generator
+
+from tenacity import (
+    AsyncRetrying,
+    RetryError,
+    Retrying,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential_jitter,
+)
+
+from graphrag.callbacks.llm_callbacks import BaseLLMCallback
+from graphrag.llm import OllamaConfiguration, create_ollama_client
+from graphrag.query.llm.base import BaseLLM
+
+
+class ChatOllama(BaseLLM):
+    """Wrapper for Ollama ChatCompletion models."""
+
+    def __init__(self, configuration: OllamaConfiguration):
+        self.configuration = configuration
+        self.sync_client = create_ollama_client(configuration, sync=True)
+        self.async_client = create_ollama_client(configuration)
+
+    def generate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """Generate a response."""
+        response = self.sync_client.chat(
+            messages=messages,
+            **{**self.configuration.get_chat_cache_args(), "stream": False},
+        )
+        return response["message"]["content"]
+
+    def stream_generate(
+        self,
+        messages: str | list[Any],
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> Generator[str, None, None]:
+        """Generate a response with streaming."""
+
+    async def agenerate(
+        self,
+        messages: str | list[Any],
+        streaming: bool = True,
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """Generate a response asynchronously."""
+        # Retry transient failures with exponential backoff.
+        try:
+            retryer = AsyncRetrying(
+                stop=stop_after_attempt(self.configuration.max_retries),
+                wait=wait_exponential_jitter(max=10),
+                reraise=True,
+                retry=retry_if_exception_type(Exception),  # type: ignore
+            )
+            async for attempt in retryer:
+                with attempt:
+                    response = await self.async_client.chat(
+                        messages=messages,
+                        **{
+                            **self.configuration.get_chat_cache_args(),
+                            "stream": False,
+                        }
+                    )
+                    return response["message"]["content"]
+        except Exception as e:
+            raise e
+
+    async def astream_generate(
+        self,
+        messages: str | list[Any],
+        callbacks: list[BaseLLMCallback] | None = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        """Generate a response asynchronously with streaming."""
+        ...
+
diff --git a/graphrag/query/llm/ollama/embeding.py b/graphrag/query/llm/ollama/embeding.py
new file mode 100644
index 0000000000..90b4f29d32
--- /dev/null
+++ b/graphrag/query/llm/ollama/embeding.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License + +"""Ollama Embedding model implementation.""" + +from typing import Any + +from graphrag.llm import OllamaConfiguration, create_ollama_client +from graphrag.query.llm.base import BaseTextEmbedding + + +class OllamaEmbedding(BaseTextEmbedding): + """Wrapper for Ollama Embedding models.""" + + def __init__(self, configuration: OllamaConfiguration): + self.configuration = configuration + self.sync_client = create_ollama_client(configuration, sync=True) + self.async_client = create_ollama_client(configuration) + + def embed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string.""" + response = self.sync_client.embed( + input=text, + **self.configuration.get_embed_cache_args(), + ) + return response["embeddings"][0] + + async def aembed(self, text: str, **kwargs: Any) -> list[float]: + """Embed a text string asynchronously.""" + response = await self.async_client.embed( + input=text, + **self.configuration.get_embed_cache_args(), + ) + return response["embeddings"][0]
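
For reviewers who want to exercise the query-side wrappers added in this series, below is a minimal sketch of how they might be driven directly against a local Ollama server. It is not part of the patch; the base URL, model names, and the placeholder api_key value are illustrative assumptions only (any locally pulled chat and embedding models will do).

# Hypothetical smoke test for ChatOllama / OllamaEmbedding (assumes `ollama serve`
# is running locally and the listed models have been pulled).
import asyncio

from graphrag.llm import OllamaConfiguration
from graphrag.query.llm.ollama import ChatOllama, OllamaEmbedding

chat_config = OllamaConfiguration({
    "api_key": "ollama",  # not used by Ollama itself, but declared as required by the config
    "model": "llama3.1",
    "api_base": "http://localhost:11434",
    "max_retries": 3,
})
embed_config = OllamaConfiguration({
    "api_key": "ollama",
    "model": "nomic-embed-text",
    "api_base": "http://localhost:11434",
})

llm = ChatOllama(chat_config)
embedder = OllamaEmbedding(embed_config)

# Synchronous chat and embedding calls.
print(llm.generate([{"role": "user", "content": "Say hello."}]))
print(len(embedder.embed("graph rag")))

# The async variants mirror the sync ones.
print(asyncio.run(llm.agenerate([{"role": "user", "content": "Say hello."}])))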