Skip to content

Commit

Permalink
Add from_file class method to the Prompt object
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 6, 2025
1 parent 3cc399d commit 8b27da3
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 50 deletions.
4 changes: 3 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,14 +8,15 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt
from outlines.prompts import Prompt, prompt

__all__ = [
"clear_cache",
"disable_cache",
"get_cache",
"Function",
"prompt",
"Prompt",
"vectorize",
"grammars",
]
118 changes: 89 additions & 29 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import functools
import inspect
import json
import os
import re
import textwrap
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, cast
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union, cast

from jinja2 import Environment, StrictUndefined
from pydantic import BaseModel
import jinja2
import pydantic


@dataclass
Expand All @@ -19,12 +22,8 @@ class Prompt:
"""

template: str
signature: inspect.Signature

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
self.jinja_environment = create_jinja_template(self.template)
template: jinja2.Template
signature: Optional[inspect.Signature]

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -34,12 +33,50 @@ def __call__(self, *args, **kwargs) -> str:
The rendered template as a Python ``str``.
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.jinja_environment.render(**bound_arguments.arguments)
if self.signature is not None:
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.template.render(**bound_arguments.arguments)
else:
return self.template.render(**kwargs)

@classmethod
def from_file(cls, path: Path):
"""
Create a Prompt instance from a file containing a Jinja template.
Note: This method does not allow to include and inheritance to reference files
that are outside the folder or subfolders of the file given to `from_file`.
Parameters
----------
path : Path
The path to the file containing the Jinja template.
Returns
-------
Prompt
An instance of the Prompt class with the template loaded from the file.
"""
# We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(template_from_file(path), None)

def __str__(self):
return self.template
@classmethod
def from_str(cls, content: str):
"""
Create an instance of the class from a string.
Parameters
----------
content : str
The string content to be converted into a template.
Returns
-------
An instance of the class with the provided content as a template.
"""
return cls(template_from_str(content), None)


def prompt(fn: Callable) -> Prompt:
Expand Down Expand Up @@ -87,12 +124,14 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)
template = template_from_str(cast(str, docstring))

return Prompt(template, signature)


def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
def render(
template: Union[jinja2.Template, str], **values: Optional[Dict[str, Any]]
) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
This function removes extra whitespaces and linebreaks from templates to
Expand Down Expand Up @@ -174,7 +213,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
Parameters
----------
template
A string that contains a template written with the Jinja2 syntax.
A Jinja2 template.
**values
Map from the variables in the template to their value.
Expand All @@ -183,17 +222,27 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
A string that contains the rendered template.
"""
jinja_template = create_jinja_template(template)
return jinja_template.render(**values)
match template:
case jinja2.Template():
return template.render(**values)
case str():
warnings.warn(
"Using `render(str)` is deprecated. Please use a `render(Prompt.from_str(str))` instead.",
DeprecationWarning,
)
template = template_from_str(template)
return template.render(**values)
case _:
raise AssertionError("This should never happen")


def create_jinja_template(template: str):
def template_from_str(content: str) -> jinja2.Template:
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(template)
cleaned_template = inspect.cleandoc(content)

# Add linebreak if there were any extra linebreaks that
# `cleandoc` would have removed
ends_with_linebreak = template.replace(" ", "").endswith("\n\n")
ends_with_linebreak = content.replace(" ", "").endswith("\n\n")
if ends_with_linebreak:
cleaned_template += "\n"

Expand All @@ -202,11 +251,11 @@ def create_jinja_template(template: str):
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = Environment(
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
undefined=jinja2.StrictUndefined,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
Expand All @@ -215,8 +264,19 @@ def create_jinja_template(template: str):
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)
return jinja_template
return env.from_string(cleaned_template)


def template_from_file(path: Path) -> jinja2.Template:
file_directory = os.path.dirname(os.path.abspath(path))
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(file_directory),
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
return env.get_template(os.path.basename(path))


def get_fn_name(fn: Callable):
Expand Down Expand Up @@ -301,10 +361,10 @@ def get_schema_dict(model: Dict):
return json.dumps(model, indent=2)


@get_schema.register(type(BaseModel))
def get_schema_pydantic(model: Type[BaseModel]):
@get_schema.register(type(pydantic.BaseModel))
def get_schema_pydantic(model: Type[pydantic.BaseModel]):
"""Return the schema of a Pydantic model."""
if not type(model) == type(BaseModel):
if not isinstance(model, type(pydantic.BaseModel)):
raise TypeError("The `schema` filter only applies to Pydantic models.")

if hasattr(model, "model_json_schema"):
Expand Down
Loading

0 comments on commit 8b27da3

Please sign in to comment.