Skip to content

Commit

Permalink
Add from_file class method to the Prompt object
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 7, 2025
1 parent 3cc399d commit 32587ee
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 160 deletions.
4 changes: 3 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,14 +8,15 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt
from outlines.prompts import Prompt, prompt

__all__ = [
"clear_cache",
"disable_cache",
"get_cache",
"Function",
"prompt",
"Prompt",
"vectorize",
"grammars",
]
166 changes: 111 additions & 55 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import functools
import inspect
import json
import os
import re
import textwrap
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, cast
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, cast

from jinja2 import Environment, StrictUndefined
from pydantic import BaseModel
import jinja2
import pydantic


@dataclass
Expand All @@ -19,12 +22,8 @@ class Prompt:
"""

template: str
signature: inspect.Signature

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
self.jinja_environment = create_jinja_template(self.template)
template: jinja2.Template
signature: Optional[inspect.Signature]

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -34,12 +33,93 @@ def __call__(self, *args, **kwargs) -> str:
The rendered template as a Python ``str``.
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.jinja_environment.render(**bound_arguments.arguments)
if self.signature is not None:
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.template.render(**bound_arguments.arguments)
else:
return self.template.render(**kwargs)

@classmethod
def from_str(cls, content: str):
    """Build a prompt from a template given as a string.

    Parameters
    ----------
    content : str
        The template text; it is converted into a template by
        `_template_from_str`.

    Returns
    -------
    An instance of the class wrapping the resulting template. No signature
    is attached (``None`` is passed in its place).
    """
    template = cls._template_from_str(content)
    return cls(template, None)

@classmethod
def from_file(cls, path: Path):
"""
Create a Prompt instance from a file containing a Jinja template.
def __str__(self):
return self.template
Note: This method does not allow includes or template inheritance to reference files
that are outside the folder or subfolders of the file given to `from_file`.
Parameters
----------
path : Path
The path to the file containing the Jinja template.
Returns
-------
Prompt
An instance of the Prompt class with the template loaded from the file.
"""
# We don't use a `Signature` here because it does not seem feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(cls._template_from_file(path), None)

@classmethod
def _template_from_str(_, content: str) -> jinja2.Template:
    """Turn a raw template string into a compiled Jinja2 template.

    The string is first cleaned up with `inspect.cleandoc`, a trailing
    linebreak is restored when the input ended with a blank line, and
    whitespace runs are collapsed. The result is compiled in an environment
    that has the module's custom filters registered.

    Parameters
    ----------
    content : str
        The template text.

    Returns
    -------
    jinja2.Template
        The template compiled from the cleaned-up text.
    """
    # Dedent the text and drop the surrounding blank lines.
    normalized = inspect.cleandoc(content)

    # `cleandoc` strips trailing linebreaks; put one back when the input
    # (ignoring spaces) ended with two consecutive newlines.
    if content.replace(" ", "").endswith("\n\n"):
        normalized = normalized + "\n"

    # Collapse whitespace runs that start at a word boundary and do not begin
    # with a newline character into a single space. Whitespace right after a
    # newline is left alone so that no spaces are introduced after backslash
    # `\` characters used to continue a line without a linebreak.
    normalized = re.sub(r"(?![\r\n])(\b\s+)", " ", normalized)

    environment = jinja2.Environment(
        undefined=jinja2.StrictUndefined,
        keep_trailing_newline=True,
        lstrip_blocks=True,
        trim_blocks=True,
    )
    # Custom filters made available inside templates.
    environment.filters.update(
        {
            "name": get_fn_name,
            "description": get_fn_description,
            "source": get_fn_source,
            "signature": get_fn_signature,
            "schema": get_schema,
            "args": get_fn_args,
        }
    )

    return environment.from_string(normalized)

@classmethod
def _template_from_file(cls, path: Path) -> jinja2.Template:
    """Load a Jinja2 template from a file on disk.

    The environment uses a `jinja2.FileSystemLoader` rooted at the directory
    that contains `path`, and the template is looked up by its file name.

    The custom filters registered for string templates (``name``,
    ``description``, ``source``, ``signature``, ``schema``, ``args``) are
    registered here as well, so a template behaves the same whether it is
    loaded from a string or from a file.

    Parameters
    ----------
    path : Path
        The path to the file containing the Jinja2 template.

    Returns
    -------
    jinja2.Template
        The template loaded from the file.
    """
    # Resolve once so the directory and the file name come from the same path.
    absolute_path = os.path.abspath(path)
    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(os.path.dirname(absolute_path)),
        trim_blocks=True,
        lstrip_blocks=True,
        keep_trailing_newline=True,
        undefined=jinja2.StrictUndefined,
    )
    # Keep the filter set in sync with `_template_from_str`; without these a
    # file-based template that uses e.g. `| schema` fails to compile.
    env.filters["name"] = get_fn_name
    env.filters["description"] = get_fn_description
    env.filters["source"] = get_fn_source
    env.filters["signature"] = get_fn_signature
    env.filters["schema"] = get_schema
    env.filters["args"] = get_fn_args

    return env.get_template(os.path.basename(absolute_path))


def prompt(fn: Callable) -> Prompt:
Expand Down Expand Up @@ -87,14 +167,18 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)
template = Prompt._template_from_str(cast(str, docstring))

return Prompt(template, signature)


def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
def render(
template: str, **values: Optional[Dict[str, Any]]
) -> str: # pragma: no cover
r"""Parse a Jinja2 template and translate it into an Outlines graph.
[DEPRECATED] Using `render(str)` is deprecated.
This function removes extra whitespaces and linebreaks from templates to
allow users to enter prompts more naturally than if they used Python's
constructs directly. See the examples for a detailed explanation.
Expand All @@ -105,12 +189,12 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
Outlines follows Jinja2's syntax
>>> import outlines
>>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
>>> outline = outlines.prompts.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
I like tomatoes and tennis
If the first line of the template is empty, `render` removes it
>>> from outlines import render
>>> from outlines.prompts import render
>>>
>>> tpl = '''
... A new string'''
Expand Down Expand Up @@ -174,7 +258,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
Parameters
----------
template
A string that contains a template written with the Jinja2 syntax.
A Jinja2 template.
**values
Map from the variables in the template to their value.
Expand All @@ -183,40 +267,12 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
A string that contains the rendered template.
"""
jinja_template = create_jinja_template(template)
return jinja_template.render(**values)


def create_jinja_template(template: str):
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(template)

# Add linebreak if there were any extra linebreaks that
# `cleandoc` would have removed
ends_with_linebreak = template.replace(" ", "").endswith("\n\n")
if ends_with_linebreak:
cleaned_template += "\n"

# Remove extra whitespaces, except those that immediately follow a newline symbol.
# This is necessary to avoid introducing whitespaces after backslash `\` characters
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
warnings.warn(
"Using `render(str)` is deprecated.",
DeprecationWarning,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
env.filters["source"] = get_fn_source
env.filters["signature"] = get_fn_signature
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)
return jinja_template
template = Prompt._template_from_str(template)
return template.render(**values)


def get_fn_name(fn: Callable):
Expand Down Expand Up @@ -301,10 +357,10 @@ def get_schema_dict(model: Dict):
return json.dumps(model, indent=2)


@get_schema.register(type(BaseModel))
def get_schema_pydantic(model: Type[BaseModel]):
@get_schema.register(type(pydantic.BaseModel))
def get_schema_pydantic(model: Type[pydantic.BaseModel]):
"""Return the schema of a Pydantic model."""
if not type(model) == type(BaseModel):
if not isinstance(model, type(pydantic.BaseModel)):
raise TypeError("The `schema` filter only applies to Pydantic models.")

if hasattr(model, "model_json_schema"):
Expand Down
Loading

0 comments on commit 32587ee

Please sign in to comment.