Support ollama #1326

Open · wants to merge 8 commits into base: main
3 changes: 3 additions & 0 deletions graphrag/config/enums.py
@@ -98,14 +98,17 @@ class LLMType(str, Enum):
# Embeddings
OpenAIEmbedding = "openai_embedding"
AzureOpenAIEmbedding = "azure_openai_embedding"
OllamaEmbedding = "ollama_embedding"

# Raw Completion
OpenAI = "openai"
AzureOpenAI = "azure_openai"
Ollama = "ollama"

# Chat Completion
OpenAIChat = "openai_chat"
AzureOpenAIChat = "azure_openai_chat"
OllamaChat = "ollama_chat"

# Debug
StaticResponse = "static_response"
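For context (not part of the diff), a minimal sketch of how the new members resolve from the string values a settings file would use; only the enum shown above is assumed:

from graphrag.config.enums import LLMType

# Each configuration string maps directly onto one of the new members.
assert LLMType("ollama") is LLMType.Ollama
assert LLMType("ollama_chat") is LLMType.OllamaChat
assert LLMType("ollama_embedding") is LLMType.OllamaEmbedding
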
109 changes: 106 additions & 3 deletions graphrag/index/llm/load_llm.py
@@ -22,6 +22,12 @@
create_openai_completion_llm,
create_openai_embedding_llm,
create_tpm_rpm_limiters,
OllamaConfiguration,
create_ollama_client,
create_ollama_chat_llm,
create_ollama_embedding_llm,
create_ollama_completion_llm,
LLMConfig,
)

if TYPE_CHECKING:
@@ -46,7 +52,6 @@ def load_llm(
) -> CompletionLLM:
"""Load the LLM for the entity extraction chain."""
on_error = _create_error_handler(callbacks)

if llm_type in loaders:
if chat_only and not loaders[llm_type]["chat"]:
msg = f"LLM type {llm_type} does not support chat"
@@ -182,6 +187,50 @@ def _load_azure_openai_embeddings_llm(
return _load_openai_embeddings_llm(on_error, cache, config, True)


def _load_ollama_completion_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
return _create_ollama_completion_llm(
OllamaConfiguration({
**_get_base_config(config),
}),
on_error,
cache,
)


def _load_ollama_chat_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
return _create_ollama_chat_llm(
OllamaConfiguration({
# Set default values
**_get_base_config(config),
}),
on_error,
cache,
)


def _load_ollama_embeddings_llm(
on_error: ErrorHandlerFn,
cache: LLMCache,
config: dict[str, Any],
):
# TODO: Inject Cache
return _create_ollama_embeddings_llm(
OllamaConfiguration({
**_get_base_config(config),
}),
on_error,
cache,
)


def _get_base_config(config: dict[str, Any]) -> dict[str, Any]:
api_key = config.get("api_key")

@@ -218,6 +267,10 @@ def _load_static_response(
"load": _load_azure_openai_completion_llm,
"chat": False,
},
LLMType.Ollama: {
"load": _load_ollama_completion_llm,
"chat": False,
},
LLMType.OpenAIChat: {
"load": _load_openai_chat_llm,
"chat": True,
@@ -226,6 +279,10 @@
"load": _load_azure_openai_chat_llm,
"chat": True,
},
LLMType.OllamaChat: {
"load": _load_ollama_chat_llm,
"chat": True,
},
LLMType.OpenAIEmbedding: {
"load": _load_openai_embeddings_llm,
"chat": False,
@@ -234,6 +291,10 @@
"load": _load_azure_openai_embeddings_llm,
"chat": False,
},
LLMType.OllamaEmbedding: {
"load": _load_ollama_embeddings_llm,
"chat": False,
},
LLMType.StaticResponse: {
"load": _load_static_response,
"chat": False,
@@ -286,7 +347,49 @@ def _create_openai_embeddings_llm(
)


def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
def _create_ollama_chat_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> CompletionLLM:
"""Create an Ollama chat llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_chat_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_ollama_completion_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> CompletionLLM:
"""Create an Ollama completion llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_completion_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_ollama_embeddings_llm(
configuration: OllamaConfiguration,
on_error: ErrorHandlerFn,
cache: LLMCache,
) -> EmbeddingLLM:
"""Create an Ollama embeddings llm."""
client = create_ollama_client(configuration=configuration)
limiter = _create_limiter(configuration)
semaphore = _create_semaphore(configuration)
return create_ollama_embedding_llm(
client, configuration, cache, limiter, semaphore, on_error=on_error
)


def _create_limiter(configuration: LLMConfig) -> LLMLimiter:
limit_name = configuration.model or configuration.deployment_name or "default"
if limit_name not in _rate_limiters:
tpm = configuration.tokens_per_minute
@@ -296,7 +399,7 @@ def _create_limiter(configuration: OpenAIConfiguration) -> LLMLimiter:
return _rate_limiters[limit_name]


def _create_semaphore(configuration: OpenAIConfiguration) -> asyncio.Semaphore | None:
def _create_semaphore(configuration: LLMConfig) -> asyncio.Semaphore | None:
limit_name = configuration.model or configuration.deployment_name or "default"
concurrency = configuration.concurrent_requests

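To summarize the wiring above, a hedged sketch (not part of the diff) of how the registry entries dispatch to the new loaders; chat_only, on_error, cache, and config stand in for the values already in scope inside load_llm:

# Hypothetical dispatch, mirroring the registry entries and the loader signatures above.
entry = loaders[LLMType.OllamaChat]           # {"load": _load_ollama_chat_llm, "chat": True}
if chat_only and not entry["chat"]:
    msg = f"LLM type {LLMType.OllamaChat} does not support chat"
    raise ValueError(msg)
llm = entry["load"](on_error, cache, config)  # builds OllamaConfiguration, client, limiter, semaphore
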
21 changes: 21 additions & 0 deletions graphrag/llm/__init__.py
@@ -42,6 +42,17 @@
LLMOutput,
OnCacheActionFn,
)
from .ollama import (
OllamaChatLLM,
OllamaClientType,
OllamaCompletionLLM,
OllamaConfiguration,
OllamaEmbeddingsLLM,
create_ollama_chat_llm,
create_ollama_client,
create_ollama_completion_llm,
create_ollama_embedding_llm,
)

__all__ = [
# LLM Types
@@ -79,13 +90,23 @@
"OpenAIConfiguration",
"OpenAIEmbeddingsLLM",
"RateLimitingLLM",
# Ollama
"OllamaChatLLM",
"OllamaClientType",
"OllamaCompletionLLM",
"OllamaConfiguration",
"OllamaEmbeddingsLLM",
# Errors
"RetriesExhaustedError",
"TpmRpmLLMLimiter",
"create_openai_chat_llm",
"create_openai_client",
"create_openai_completion_llm",
"create_openai_embedding_llm",
"create_ollama_chat_llm",
"create_ollama_client",
"create_ollama_completion_llm",
"create_ollama_embedding_llm",
# Limiters
"create_tpm_rpm_limiters",
]
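With these re-exports, the Ollama pieces become importable from the package root. A hypothetical usage sketch; the dictionary keys are assumptions inferred from how load_llm.py builds an OllamaConfiguration from a plain dict:

from graphrag.llm import OllamaConfiguration, create_ollama_client

configuration = OllamaConfiguration({
    "model": "llama3",                     # hypothetical model name
    "api_base": "http://localhost:11434",  # Ollama's default local endpoint
})
client = create_ollama_client(configuration=configuration)  # AsyncClient by default
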
29 changes: 29 additions & 0 deletions graphrag/llm/ollama/__init__.py
@@ -0,0 +1,29 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Ollama LLM implementations."""

from .create_ollama_client import create_ollama_client
from .factories import (
create_ollama_chat_llm,
create_ollama_completion_llm,
create_ollama_embedding_llm,
)
from .ollama_chat_llm import OllamaChatLLM
from .ollama_completion_llm import OllamaCompletionLLM
from .ollama_configuration import OllamaConfiguration
from .ollama_embeddings_llm import OllamaEmbeddingsLLM
from .types import OllamaClientType


__all__ = [
"OllamaChatLLM",
"OllamaClientType",
"OllamaCompletionLLM",
"OllamaConfiguration",
"OllamaEmbeddingsLLM",
"create_ollama_chat_llm",
"create_ollama_client",
"create_ollama_completion_llm",
"create_ollama_embedding_llm",
]
36 changes: 36 additions & 0 deletions graphrag/llm/ollama/create_ollama_client.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Create OpenAI client instance."""

import logging
from functools import cache

from ollama import AsyncClient, Client

from .ollama_configuration import OllamaConfiguration
from .types import OllamaClientType

log = logging.getLogger(__name__)

API_BASE_REQUIRED_FOR_AZURE = "api_base is required for Azure OpenAI client"


@cache
def create_ollama_client(
configuration: OllamaConfiguration,
sync: bool = False,
) -> OllamaClientType:
"""Create a new Ollama client instance."""

log.info("Creating Ollama client base_url=%s", configuration.api_base)
if sync:
return Client(
host=configuration.api_base,
timeout=configuration.request_timeout or 180.0,
)
return AsyncClient(
host=configuration.api_base,
# Retries are handled upstream via Tenacity, so only the timeout is configured here
timeout=configuration.request_timeout or 180.0,
)
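Since the factory above is memoized with functools.cache and falls back to a 180-second timeout, repeated calls with the same configuration reuse one client (this assumes OllamaConfiguration is hashable, which functools.cache requires). A small sketch of the sync/async switch, reusing the configuration object from the previous example:

async_client = create_ollama_client(configuration=configuration)            # ollama.AsyncClient
sync_client = create_ollama_client(configuration=configuration, sync=True)  # ollama.Client
assert create_ollama_client(configuration=configuration) is async_client    # cached instance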