Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement prompt tuning API #855

Merged
merged 11 commits into from
Aug 12, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20240807023736041951.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Implement auto templating API."
}
82 changes: 39 additions & 43 deletions graphrag/prompt_tune/__main__.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,48 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""The Prompt auto templating package root."""
"""The auto templating package root."""

import argparse
import asyncio
from enum import Enum

from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE

from .api import DocSelectionType
from .cli import prompt_tune


class DocSelectionType(Enum):
"""The type of document selection to use."""

ALL = "all"
RANDOM = "random"
TOP = "top"
AUTO = "auto"

def __str__(self):
"""Return the string representation of the enum value."""
return self.value

from .generator import MAX_TOKEN_COUNT
from .loader import MIN_CHUNK_SIZE

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
prog="python -m graphrag.prompt_tune",
description="The graphrag auto templating module.",
)

parser.add_argument(
"--config",
help="Configuration yaml file to use when generating prompts",
required=True,
type=str,
)

parser.add_argument(
"--root",
help="The data project root. Including the config yml, json or .env",
help="Data project root. Default: current directory",
required=False,
type=str,
default=".",
)

parser.add_argument(
"--domain",
help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
help="Domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, the domain will be inferred from the input data.",
required=False,
default="",
type=str,
)

parser.add_argument(
"--method",
help="The method to select documents, one of: all, random, top or auto",
"--selection-method",
help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}",
required=False,
type=DocSelectionType,
choices=list(DocSelectionType),
Expand All @@ -56,47 +51,47 @@ def __str__(self):

parser.add_argument(
"--n_subset_max",
help="The number of text chunks to embed when using auto selection method",
help="Number of text chunks to embed when using auto selection method. Default: 300",
required=False,
type=int,
default=300,
)

parser.add_argument(
"--k",
help="The maximum number of documents to select from each centroid when using auto selection method",
help="Maximum number of documents to select from each centroid when using auto selection method. Default: 15",
required=False,
type=int,
default=15,
)

parser.add_argument(
"--limit",
help="The limit of files to load when doing random or top selection",
help="Number of documents to load when doing random or top selection. Default: 15",
type=int,
required=False,
default=15,
)

parser.add_argument(
"--max-tokens",
help="Max token count for prompt generation",
help=f"Max token count for prompt generation. Default: {MAX_TOKEN_COUNT}",
type=int,
required=False,
default=MAX_TOKEN_COUNT,
)

parser.add_argument(
"--min-examples-required",
help="The minimum number of examples required in entity extraction prompt",
help="Minimum number of examples required in the entity extraction prompt. Default: 2",
type=int,
required=False,
default=2,
)

parser.add_argument(
"--chunk-size",
help="Max token count for prompt generation",
help=f"Max token count for prompt generation. Default: {MIN_CHUNK_SIZE}",
type=int,
required=False,
default=MIN_CHUNK_SIZE,
Expand All @@ -120,7 +115,7 @@ def __str__(self):

parser.add_argument(
"--output",
help="Folder to save the generated prompts to",
help="Directory to save generated prompts to. Default: 'prompts'",
type=str,
required=False,
default="prompts",
Expand All @@ -132,17 +127,18 @@ def __str__(self):

loop.run_until_complete(
prompt_tune(
args.root,
args.domain,
str(args.method),
args.limit,
args.max_tokens,
args.chunk_size,
args.language,
args.no_entity_types,
args.output,
args.n_subset_max,
args.k,
args.min_examples_required,
config=args.config,
root=args.root,
domain=args.domain,
selection_method=args.selection_method,
limit=args.limit,
max_tokens=args.max_tokens,
chunk_size=args.chunk_size,
language=args.language,
skip_entity_types=args.no_entity_types,
output=args.output,
n_subset_max=args.n_subset_max,
k=args.k,
min_examples_required=args.min_examples_required,
)
)
173 changes: 173 additions & 0 deletions graphrag/prompt_tune/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
Auto Templating API.

This API provides access to the auto templating feature of graphrag, allowing external applications
to hook into graphrag and generate prompts from private data.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from datashaper import NoopVerbCallbacks
from pydantic import PositiveInt, validate_call

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter

from .cli import DocSelectionType
from .generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
create_entity_summarization_prompt,
detect_language,
generate_community_report_rating,
generate_community_reporter_role,
generate_domain,
generate_entity_relationship_examples,
generate_entity_types,
generate_persona,
)
from .loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
)


@validate_call
async def generate_indexing_prompts(
config: GraphRagConfig,
root: str,
chunk_size: PositiveInt = MIN_CHUNK_SIZE,
limit: PositiveInt = 15,
selection_method: DocSelectionType = DocSelectionType.RANDOM,
domain: str | None = None,
language: str | None = None,
max_tokens: int = MAX_TOKEN_COUNT,
skip_entity_types: bool = False,
min_examples_required: PositiveInt = 2,
n_subset_max: PositiveInt = 300,
k: PositiveInt = 15,
) -> tuple[str, str, str]:
"""Generate indexing prompts.

Parameters
----------
- config: The GraphRag configuration.
- output_path: The path to store the prompts.
- chunk_size: The chunk token size to use for input text units.
- limit: The limit of chunks to load.
- selection_method: The chunk selection method.
- domain: The domain to map the input documents to.
- language: The language to use for the prompts.
- max_tokens: The maximum number of tokens to use on entity extraction prompts
- skip_entity_types: Skip generating entity types.
- min_examples_required: The minimum number of examples required for entity extraction prompts.
- n_subset_max: The number of text chunks to embed when using auto selection method.
- k: The number of documents to select when using auto selection method.

Returns
-------
tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
"""
reporter = PrintProgressReporter("")

# Retrieve documents
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=selection_method,
reporter=reporter,
chunk_size=chunk_size,
n_subset_max=n_subset_max,
k=k,
)

# Create LLM from config
llm = load_llm(
"prompt_tuning",
config.llm.type,
NoopVerbCallbacks(),
None,
config.llm.model_dump(),
)

if not domain:
reporter.info("Generating domain...")
domain = await generate_domain(llm, doc_list)
reporter.info(f"Generated domain: {domain}")

if not language:
reporter.info("Detecting language...")
language = await detect_language(llm, doc_list)

reporter.info("Generating persona...")
persona = await generate_persona(llm, domain)

reporter.info("Generating community report ranking description...")
community_report_ranking = await generate_community_report_rating(
llm, domain=domain, persona=persona, docs=doc_list
)

entity_types = None
if not skip_entity_types:
reporter.info("Generating entity types...")
entity_types = await generate_entity_types(
llm,
domain=domain,
persona=persona,
docs=doc_list,
json_mode=config.llm.model_supports_json or False,
)

reporter.info("Generating entity relationship examples...")
examples = await generate_entity_relationship_examples(
llm,
persona=persona,
entity_types=entity_types,
docs=doc_list,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
)

reporter.info("Generating entity extraction prompt...")
entity_extraction_prompt = create_entity_extraction_prompt(
entity_types=entity_types,
docs=doc_list,
examples=examples,
language=language,
json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json by the index engine
encoding_model=config.encoding_model,
max_token_count=max_tokens,
min_examples_required=min_examples_required,
)

reporter.info("Generating entity summarization prompt...")
entity_summarization_prompt = create_entity_summarization_prompt(
persona=persona,
language=language,
)

reporter.info("Generating community reporter role...")
community_reporter_role = await generate_community_reporter_role(
llm, domain=domain, persona=persona, docs=doc_list
)

reporter.info("Generating community summarization prompt...")
community_summarization_prompt = create_community_summarization_prompt(
persona=persona,
role=community_reporter_role,
report_rating_description=community_report_ranking,
language=language,
)

return (
entity_extraction_prompt,
entity_summarization_prompt,
community_summarization_prompt,
)
Loading
Loading