From 8b27da3d137c8f88a4436733fc2a3105e051e549 Mon Sep 17 00:00:00 2001 From: Yvan Sraka Date: Thu, 2 Jan 2025 14:50:00 +0100 Subject: [PATCH] Add `from_file` class method to the `Prompt` object --- outlines/__init__.py | 4 +- outlines/prompts.py | 118 ++++++++++++++++++++++++++++++---------- tests/test_prompts.py | 122 +++++++++++++++++++++++++++++++++++------- 3 files changed, 194 insertions(+), 50 deletions(-) diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..eeba78dc7 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,7 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function -from outlines.prompts import prompt +from outlines.prompts import Prompt, prompt __all__ = [ "clear_cache", @@ -15,6 +16,7 @@ "get_cache", "Function", "prompt", + "Prompt", "vectorize", "grammars", ] diff --git a/outlines/prompts.py b/outlines/prompts.py index a7824451a..b73738be7 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -1,13 +1,16 @@ import functools import inspect import json +import os import re import textwrap +import warnings from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type, cast +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, Union, cast -from jinja2 import Environment, StrictUndefined -from pydantic import BaseModel +import jinja2 +import pydantic @dataclass @@ -19,12 +22,8 @@ class Prompt: """ - template: str - signature: inspect.Signature - - def __post_init__(self): - self.parameters: List[str] = list(self.signature.parameters.keys()) - self.jinja_environment = create_jinja_template(self.template) + template: jinja2.Template + signature: Optional[inspect.Signature] def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -34,12 +33,50 @@ def __call__(self, *args, **kwargs) -> str: The rendered template as a Python ``str``. """ - bound_arguments = self.signature.bind(*args, **kwargs) - bound_arguments.apply_defaults() - return self.jinja_environment.render(**bound_arguments.arguments) + if self.signature is not None: + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return self.template.render(**bound_arguments.arguments) + else: + return self.template.render(**kwargs) + + @classmethod + def from_file(cls, path: Path): + """ + Create a Prompt instance from a file containing a Jinja template. + + Note: This method does not allow to include and inheritance to reference files + that are outside the folder or subfolders of the file given to `from_file`. + + Parameters + ---------- + path : Path + The path to the file containing the Jinja template. + + Returns + ------- + Prompt + An instance of the Prompt class with the template loaded from the file. + """ + # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is + # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) + return cls(template_from_file(path), None) - def __str__(self): - return self.template + @classmethod + def from_str(cls, content: str): + """ + Create an instance of the class from a string. + + Parameters + ---------- + content : str + The string content to be converted into a template. + + Returns + ------- + An instance of the class with the provided content as a template. + """ + return cls(template_from_str(content), None) def prompt(fn: Callable) -> Prompt: @@ -87,12 +124,14 @@ def prompt(fn: Callable) -> Prompt: if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = cast(str, docstring) + template = template_from_str(cast(str, docstring)) return Prompt(template, signature) -def render(template: str, **values: Optional[Dict[str, Any]]) -> str: +def render( + template: Union[jinja2.Template, str], **values: Optional[Dict[str, Any]] +) -> str: r"""Parse a Jinaj2 template and translate it into an Outlines graph. This function removes extra whitespaces and linebreaks from templates to @@ -174,7 +213,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: Parameters ---------- template - A string that contains a template written with the Jinja2 syntax. + A Jinja2 template. **values Map from the variables in the template to their value. @@ -183,17 +222,27 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: A string that contains the rendered template. """ - jinja_template = create_jinja_template(template) - return jinja_template.render(**values) + match template: + case jinja2.Template(): + return template.render(**values) + case str(): + warnings.warn( + "Using `render(str)` is deprecated. Please use a `render(Prompt.from_str(str))` instead.", + DeprecationWarning, + ) + template = template_from_str(template) + return template.render(**values) + case _: + raise AssertionError("This should never happen") -def create_jinja_template(template: str): +def template_from_str(content: str) -> jinja2.Template: # Dedent, and remove extra linebreak - cleaned_template = inspect.cleandoc(template) + cleaned_template = inspect.cleandoc(content) # Add linebreak if there were any extra linebreaks that # `cleandoc` would have removed - ends_with_linebreak = template.replace(" ", "").endswith("\n\n") + ends_with_linebreak = content.replace(" ", "").endswith("\n\n") if ends_with_linebreak: cleaned_template += "\n" @@ -202,11 +251,11 @@ def create_jinja_template(template: str): # used to continue to the next line without linebreak. cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) - env = Environment( + env = jinja2.Environment( trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True, - undefined=StrictUndefined, + undefined=jinja2.StrictUndefined, ) env.filters["name"] = get_fn_name env.filters["description"] = get_fn_description @@ -215,8 +264,19 @@ def create_jinja_template(template: str): env.filters["schema"] = get_schema env.filters["args"] = get_fn_args - jinja_template = env.from_string(cleaned_template) - return jinja_template + return env.from_string(cleaned_template) + + +def template_from_file(path: Path) -> jinja2.Template: + file_directory = os.path.dirname(os.path.abspath(path)) + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(file_directory), + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + return env.get_template(os.path.basename(path)) def get_fn_name(fn: Callable): @@ -301,10 +361,10 @@ def get_schema_dict(model: Dict): return json.dumps(model, indent=2) -@get_schema.register(type(BaseModel)) -def get_schema_pydantic(model: Type[BaseModel]): +@get_schema.register(type(pydantic.BaseModel)) +def get_schema_pydantic(model: Type[pydantic.BaseModel]): """Return the schema of a Pydantic model.""" - if not type(model) == type(BaseModel): + if not isinstance(model, type(pydantic.BaseModel)): raise TypeError("The `schema` filter only applies to Pydantic models.") if hasattr(model, "model_json_schema"): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index a0433c0e5..a895744fd 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,56 +1,73 @@ +import os +import tempfile from typing import Dict, List import pytest from pydantic import BaseModel, Field import outlines -from outlines.prompts import render +from outlines.prompts import Prompt, render, template_from_str def test_render(): - tpl = """ + tpl = template_from_str( + """ A test string""" + ) assert render(tpl) == "A test string" - tpl = """ + tpl = template_from_str( + """ A test string """ + ) assert render(tpl) == "A test string" - tpl = """ + tpl = template_from_str( + """ A test Another test """ + ) assert render(tpl) == "A test\nAnother test" - tpl = """A test + tpl = template_from_str( + """A test Another test """ + ) assert render(tpl) == "A test\nAnother test" - tpl = """ + tpl = template_from_str( + """ A test line An indented line """ + ) assert render(tpl) == "A test line\n An indented line" - tpl = """ + tpl = template_from_str( + """ A test line An indented line """ + ) assert render(tpl) == "A test line\n An indented line\n" def test_render_escaped_linebreak(): - tpl = """ + tpl = template_from_str( + """ A long test \ that we break \ in several lines """ + ) assert render(tpl) == "A long test that we break in several lines" - tpl = """ + tpl = template_from_str( + """ Break in \ several lines \ But respect the indentation @@ -58,6 +75,7 @@ def test_render_escaped_linebreak(): And after everything \ Goes back to normal """ + ) assert ( render(tpl) == "Break in several lines But respect the indentation\n on line breaks.\nAnd after everything Goes back to normal" @@ -72,10 +90,12 @@ def test_render_jinja(): # Notice the newline after the end of the loop examples = ["one", "two"] prompt = render( - """ + template_from_str( + """ {% for e in examples %} Example: {{e}} - {% endfor -%}""", + {% endfor -%}""" + ), examples=examples, ) assert prompt == "Example: one\nExample: two\n" @@ -83,24 +103,28 @@ def test_render_jinja(): # We can remove the newline by cloing with -%} examples = ["one", "two"] prompt = render( - """ + template_from_str( + """ {% for e in examples %} Example: {{e}} {% endfor -%} - Final""", + Final""" + ), examples=examples, ) assert prompt == "Example: one\nExample: two\nFinal" # Same for conditionals - tpl = """ + tpl = template_from_str( + """ {% if is_true %} true {% endif -%} final """ + ) assert render(tpl, is_true=True) == "true\nfinal" assert render(tpl, is_true=False) == "final" @@ -110,9 +134,6 @@ def test_prompt_basic(): def test_tpl(variable): """{{variable}} test""" - assert test_tpl.template == "{{variable}} test" - assert test_tpl.parameters == ["variable"] - with pytest.raises(TypeError): test_tpl(v="test") @@ -135,9 +156,6 @@ def test_prompt_kwargs(): def test_kwarg_tpl(var, other_var="other"): """{{var}} and {{other_var}}""" - assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" - assert test_kwarg_tpl.parameters == ["var", "other_var"] - p = test_kwarg_tpl("test") assert p == "test and other" @@ -312,3 +330,67 @@ def args_prompt(fn): args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" ) + + +@pytest.fixture +def temp_prompt_file(): + test_dir = tempfile.mkdtemp() + + base_template_path = os.path.join(test_dir, "base_template.txt") + with open(base_template_path, "w") as f: + f.write( + """{% block content %}{% endblock %} +""" + ) + + include_file_path = os.path.join(test_dir, "include.txt") + with open(include_file_path, "w") as f: + f.write( + """{% for example in examples %} +- Q: {{ example.question }} +- A: {{ example.answer }} +{% endfor %} +""" + ) + + prompt_file_path = os.path.join(test_dir, "prompt.txt") + with open(prompt_file_path, "w") as f: + f.write( + """{% extends "base_template.txt" %} + +{% block content %} +Here is a prompt with examples: + +{% include "include.txt" %} + +Now please answer the following question: + +Q: {{ question }} +A: +{% endblock %} +""" + ) + yield prompt_file_path + + +def test_prompt_from_file(temp_prompt_file): + prompt = Prompt.from_file(temp_prompt_file) + examples = [ + {"question": "What is the capital of France?", "answer": "Paris"}, + {"question": "What is 2 + 2?", "answer": "4"}, + ] + question = "What is the Earth's diameter?" + rendered = prompt(examples=examples, question=question) + expected = """Here is a prompt with examples: + +- Q: What is the capital of France? +- A: Paris +- Q: What is 2 + 2? +- A: 4 + +Now please answer the following question: + +Q: What is the Earth's diameter? +A: +""" + assert rendered.strip() == expected.strip()