Skip to content

Commit

Permalink
dummy provider and threading bugfix (#418)
Browse files Browse the repository at this point in the history
* example and bugfix to queue full

* minor

* typo

* PR comments

---------

Co-authored-by: rshih32 <[email protected]>
  • Loading branch information
piotrm0 and rshih32 authored Sep 7, 2023
1 parent bbb86a5 commit 1ff1222
Show file tree
Hide file tree
Showing 14 changed files with 250 additions and 40 deletions.
107 changes: 107 additions & 0 deletions trulens_eval/examples/dummy_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dummy Example\n",
"\n",
"This notebook shows the use of the dummy feedback function provider which\n",
"behaves like the huggingface provider except it does not actually perform any\n",
"network calls and just produces constant results. It can be used to prototype\n",
"feedback function wiring for your apps before invoking potentially slow (to\n",
"run/to load) feedback functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import threading\n",
"\n",
"from examples.frameworks.custom.custom_app import CustomApp\n",
"\n",
"from trulens_eval import Feedback\n",
"from trulens_eval import Tru\n",
"from trulens_eval.feedback.provider.hugs import Dummy\n",
"from trulens_eval.tru_custom_app import TruCustomApp\n",
"from trulens_eval.utils.threading import TP\n",
"\n",
"tru = Tru()\n",
"\n",
"tru.reset_database()\n",
"\n",
"tru.start_dashboard()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hugs = Huggingface()\n",
"hugs = Dummy()\n",
"\n",
"f_positive_sentiment = Feedback(hugs.positive_sentiment).on_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create custom app:\n",
"ca = CustomApp()\n",
"\n",
"# Create trulens wrapper:\n",
"ta = TruCustomApp(\n",
" ca,\n",
" app_id=\"customapp\",\n",
" # feedback_mode=FeedbackMode.WITH_APP\n",
" feedbacks=[f_positive_sentiment]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with ta:\n",
" for i, q in enumerate([\"hello there\"] * 100):\n",
" # Track number of requests, number of threads, and number of promises to fulfill\n",
" print(f\"\\rrequest {i} \", end=\"\")\n",
" print(f\"thread count={threading.active_count()}, promises={TP().promises.qsize()}\", end=\"\")\n",
"\n",
" res = ca.respond_to_query(input=q)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py38_trulens",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
19 changes: 1 addition & 18 deletions trulens_eval/examples/frameworks/custom/custom_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,6 @@
"This example uses several other python files in the same folder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"from pathlib import Path\n",
"import sys\n",
"\n",
"# If running from github repo, can use this:\n",
"sys.path.append(str(Path().cwd().parent.parent.parent.resolve()))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -124,9 +109,7 @@
"source": [
"# Start the dashboard. If you are running from github repo, you will need to adjust\n",
"# the path the dashboard streamlit app starts in by providing the _dev argument.\n",
"tru.start_dashboard(\n",
" force=True, _dev=Path().cwd().parent.parent.parent.resolve()\n",
")"
"tru.start_dashboard()"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion trulens_eval/examples/frameworks/custom/custom_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ def __init__(self, model: str = "derp"):

@instrument
def generate(self, prompt: str):
sleep(0.1)
sleep(0.01)

return "herp " + prompt[::-1] + " derp"
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class CustomRetriever:

# @instrument
def retrieve_chunks(self, data):
sleep(0.15)
sleep(0.015)

return [
f"Relevant chunk: {data.upper()}", f"Relevant chunk: {data[::-1]}"
Expand Down
6 changes: 6 additions & 0 deletions trulens_eval/trulens_eval/feedback/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def _next_unselected_arg_name(self):
f"Feedback function `{self.imp.__name__}` has `self` as argument. "
"Perhaps it is static method or its Provider class was not initialized?"
)
if len(par_names) == 0:
raise TypeError(
f"Feedback implementation {self.imp} with signature {sig} has no more inputs. "
"Perhaps you meant to evalute it on App output only instead of app input and output?"
)

return par_names[0]
else:
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions trulens_eval/trulens_eval/feedback/provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class Config:

endpoint: Optional[Endpoint]

def __init__(self, *args, **kwargs):
def __init__(self, name: str = None, **kwargs):
# for WithClassInfo:
kwargs['obj'] = self

super().__init__(*args, **kwargs)
super().__init__(name=name, **kwargs)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from trulens_eval.feedback.provider.endpoint.base import Endpoint
from trulens_eval.feedback.provider.endpoint.base import Endpoint, DummyEndpoint
from trulens_eval.feedback.provider.endpoint.hugs import HuggingfaceEndpoint
from trulens_eval.feedback.provider.endpoint.openai import OpenAIEndpoint

__all__ = ['Endpoint', 'HuggingfaceEndpoint', 'OpenAIEndpoint']
__all__ = ['Endpoint', 'DummyEndpoint', 'HuggingfaceEndpoint', 'OpenAIEndpoint']
86 changes: 84 additions & 2 deletions trulens_eval/trulens_eval/feedback/provider/endpoint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from pprint import PrettyPrinter
from queue import Queue
import random
from threading import Thread
from time import sleep
from types import AsyncGeneratorType
Expand Down Expand Up @@ -97,9 +98,9 @@ class Config:
# Thread that fills the queue at the appropriate rate.
pace_thread: Thread = pydantic.Field(exclude=True)

def __new__(cls, name: str, *args, **kwargs):
def __new__(cls, *args, name: str=None, **kwargs):
return super(SingletonPerName, cls).__new__(
SerialModel, name=name, *args, **kwargs
SerialModel, *args, name=name, **kwargs
)

def __init__(self, *args, name: str, callback_class: Any, **kwargs):
Expand Down Expand Up @@ -772,3 +773,84 @@ def wrapper(*args, **kwargs):
logger.debug(f"Instrumenting {func.__name__} for {self.name} .")

return w


class DummyEndpoint(Endpoint):
    """
    Endpoint for testing purposes. Should not make any network calls.

    `post` fabricates responses shaped like huggingface classification
    results instead of issuing a request: while `is_loading` is set it
    produces a "model loading" message, afterwards it randomly produces an
    "overloaded" error, and otherwise it produces a constant list of label
    scores. The loading and overloaded responses are handled by sleeping and
    retrying.
    """

    # Pretend the model we are querying is loading as is in huggingface.
    is_loading: bool = True

    def __new__(cls, *args, **kwargs):
        # The instance is always requested under the fixed name
        # "dummyendpoint"; positional and keyword arguments are not forwarded.
        return super(Endpoint, cls).__new__(cls, name="dummyendpoint")

    def __init__(self, name: str = "dummyendpoint", **kwargs):
        """
        Initialize the endpoint unless it has already been initialized.

        Args:
            name (str): Endpoint name forwarded to the parent constructor.
            **kwargs: Additional parent constructor arguments. The `name`,
                `callback_class` and `rpm` entries are overwritten here.
        """

        if hasattr(self, "callback_class"):
            # Already created with SingletonPerName mechanism
            return

        kwargs['name'] = name
        kwargs['callback_class'] = EndpointCallback
        kwargs['rpm'] = DEFAULT_RPM * 10

        super().__init__(**kwargs)

    def post(
        self, url: str, payload: JSON, timeout: Optional[int] = None
    ) -> Any:
        """
        Pretend to post `payload` to `url` and return a classification result.

        Args:
            url (str): Target url. Only passed along on retries.
            payload (JSON): Request payload. Only passed along on retries.
            timeout (Optional[int]): Request timeout. Only passed along on
                retries.

        Returns:
            The first element of the fabricated response: a list of dicts
            with `label` and `score` keys.

        Raises:
            RuntimeError: If the fabricated response carries an error other
                than "overloaded".
        """

        # classification results only, like from huggingface

        self.pace_me()

        # pretend to do this:
        #
        #   ret = requests.post(
        #       url, json=payload, timeout=timeout, headers=self.post_headers
        #   )

        if self.is_loading:
            # "model loading message"
            j = dict(estimated_time=1.2345)
            self.is_loading = False

        elif random.randint(a=0, b=50) == 0:
            # randomly overloaded
            j = dict(error="overloaded")

        else:
            # otherwise a constant success

            j = [[
                {'label': 'LABEL_1', 'score': 0.6034979224205017},
                {'label': 'LABEL_2', 'score': 0.2648237645626068},
                {'label': 'LABEL_0', 'score': 0.13167837262153625}
            ]]

        # The rest is the same as in Endpoint:

        # Huggingface public api sometimes tells us that a model is loading and
        # how long to wait:
        if "estimated_time" in j:
            wait_time = j['estimated_time']
            logger.error(f"Waiting for {j} ({wait_time}) second(s).")
            sleep(wait_time + 2)
            # Retry with the caller's timeout rather than silently dropping it.
            return self.post(url, payload, timeout=timeout)

        if isinstance(j, Dict) and "error" in j:
            error = j['error']
            logger.error(f"API error: {j}.")
            if error == "overloaded":
                logger.error("Waiting for overloaded API before trying again.")
                sleep(10)
                # Retry with the caller's timeout rather than silently
                # dropping it.
                return self.post(url, payload, timeout=timeout)
            else:
                raise RuntimeError(error)

        assert isinstance(
            j, Sequence
        ) and len(j) > 0, f"Post did not return a sequence: {j}"

        return j[0]
3 changes: 2 additions & 1 deletion trulens_eval/trulens_eval/feedback/provider/endpoint/hugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class HuggingfaceEndpoint(Endpoint, WithClassInfo):
"""
Huggingface. Instruments the requests.post method for requests to
"https://api-inference.huggingface.co".
"""
"""

def __new__(cls, *args, **kwargs):
return super(Endpoint, cls).__new__(cls, name="huggingface")
Expand Down Expand Up @@ -79,3 +79,4 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._instrument_class(requests, "post")

23 changes: 21 additions & 2 deletions trulens_eval/trulens_eval/feedback/provider/hugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from trulens_eval.feedback.provider.base import Provider
from trulens_eval.feedback.provider.endpoint import HuggingfaceEndpoint
from trulens_eval.feedback.provider.endpoint.base import Endpoint
from trulens_eval.feedback.provider.endpoint.base import DummyEndpoint
from trulens_eval.utils.threading import TP

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -51,6 +52,8 @@ def wrapper(*args, **kwargs):
raise ValueError(f"{pident} must be non-empty.")

return func(*bindings.args, **bindings.kwargs)

wrapper.__signature__ = sig

return wrapper

Expand All @@ -60,7 +63,7 @@ class Huggingface(Provider):
"""
endpoint: Endpoint

def __init__(self, endpoint=None, **kwargs):
def __init__(self, name: str = None, endpoint=None, **kwargs):
# NOTE(piotrm): pydantic adds endpoint to the signature of this
# constructor if we don't include it explicitly, even though we set it
# down below. Adding it as None here as a temporary hack.
Expand All @@ -79,8 +82,15 @@ def __init__(self, endpoint=None, **kwargs):
endpoint (Endpoint): Internal Usage for DB serialization
"""

kwargs['name'] = name

self_kwargs = dict()
self_kwargs['endpoint'] = HuggingfaceEndpoint(**kwargs)
if endpoint is None:
self_kwargs['endpoint'] = HuggingfaceEndpoint(**kwargs)
else:
self_kwargs['endpoint'] = endpoint

self_kwargs['name'] = name or "huggingface"

super().__init__(
**self_kwargs
Expand Down Expand Up @@ -257,3 +267,12 @@ def _doc_groundedness(self, premise: str, hypothesis: str) -> float:
for label in hf_response:
if label['label'] == 'entailment':
return label['score']

class Dummy(Huggingface):
    """
    Huggingface provider variant that is wired to a `DummyEndpoint`.
    """

    def __init__(self, name: str = None, **kwargs):
        """
        Args:
            name (str): Provider name; "dummyhugs" is used when not given.
            **kwargs: Forwarded to the parent constructor. Any `name` or
                `endpoint` entries are overwritten here.
        """

        kwargs.update(
            name=name or "dummyhugs",
            endpoint=DummyEndpoint(name="dummyendhugspoint"),
        )

        super().__init__(**kwargs)
6 changes: 2 additions & 4 deletions trulens_eval/trulens_eval/tru_custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,13 @@ class will not be found by trulens.

import logging
from pprint import PrettyPrinter
from typing import Any, Callable, ClassVar, Iterable, Set
from typing import Any, Callable, ClassVar, Set

from pydantic import Field

from trulens_eval.app import App
from trulens_eval.instruments import Instrument
from trulens_eval.instruments import instrument as base_instrument
from trulens_eval.utils.pyschema import Class
from trulens_eval.utils.pyschema import FunctionOrMethod
from trulens_eval.utils.serial import JSONPath
Expand Down Expand Up @@ -392,9 +393,6 @@ def __getattr__(self, __name: str) -> Any:
)


from trulens_eval.instruments import instrument as base_instrument


class instrument(base_instrument):
"""
Decorator for marking methods to be instrumented in custom classes that are
Expand Down
Loading

0 comments on commit 1ff1222

Please sign in to comment.