Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xz/plugin embedding dump #101

Merged
merged 12 commits into from
Jan 10, 2024
68 changes: 68 additions & 0 deletions scripts/plugin_mgt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import argparse
import os
import sys

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

from injector import Injector

from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory.plugin import PluginModule

# Command-line interface of the plugin management script.
# The default project directory is `<directory of this script>/../project`.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_default_project_dir = os.path.join(_script_dir, "..", "project")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--project_dir",
    type=str,
    default=_default_project_dir,
    help="The project directory for the TaskWeaver.",
)
# Both actions are opt-in boolean switches; they may be combined in one call.
parser.add_argument(
    "--refresh",
    action="store_true",
    help="Refresh plugin embeddings.",
)
parser.add_argument(
    "--show",
    action="store_true",
    help="Show plugin information.",
)

# Parsed at module level because PluginManager reads `args.project_dir`.
args = parser.parse_args()


class PluginManager:
    """Small facade used by the script to refresh and inspect plugin embeddings.

    The instance owns a ``PluginSelector`` that is built through dependency
    injection from the project directory given on the command line.
    """

    def __init__(self):
        """Build the injector, bind the project configuration and create the selector."""
        container = Injector([LoggingModule, PluginModule])
        project_dir = args.project_dir
        config_source = AppConfigSource(
            config_file_path=os.path.join(project_dir, "taskweaver_config.json"),
            app_base_path=project_dir,
        )
        container.binder.bind(AppConfigSource, to=config_source)
        self.plugin_selector = container.create_object(PluginSelector)

    def refresh(self):
        """Delegate to the selector's ``refresh`` and report completion on stdout."""
        self.plugin_selector.refresh()
        print("Plugin embeddings refreshed.")

    def show(self):
        """Print name, description, embedding info, args and returns of every plugin."""
        entries = self.plugin_selector.available_plugins
        if len(entries) == 0:
            print("No available plugins.")
            return

        for entry in entries:
            spec = entry.spec
            print(f"* Plugin Name: {entry.name}")
            print(f"* Plugin Description: {spec.description}")
            print(f"* Plugin Embedding dim: {len(entry.meta_data.embedding)}")
            print(f"* Plugin Embedding model: {entry.meta_data.embedding_model}")
            print(f"* Plugin Args: {spec.args}")
            print(f"* Plugin Returns: {spec.returns}")
            # Visual separator between two plugins.
            print("_________________________________")


if __name__ == "__main__":
    # The manager is built once; each requested action is then run in the
    # fixed order refresh -> show, looked up by its command-line flag name.
    manager = PluginManager()
    for action_name in ("refresh", "show"):
        if getattr(args, action_name):
            getattr(manager, action_name)()
12 changes: 6 additions & 6 deletions taskweaver/code_interpreter/code_generator/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(

if self.config.enable_auto_plugin_selection:
self.plugin_selector = PluginSelector(plugin_registry, self.llm_api)
self.plugin_selector.generate_plugin_embeddings()
logger.info("Plugin embeddings generated")
self.plugin_selector.load_plugin_embeddings()
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

def configure_verification(
Expand Down Expand Up @@ -272,10 +272,10 @@ def compose_conversation(

def select_plugins_for_prompt(
self,
user_query: str,
query: str,
) -> List[PluginEntry]:
selected_plugins = self.plugin_selector.plugin_select(
user_query,
query,
self.config.auto_plugin_selection_topk,
)
self.selected_plugin_pool.add_selected_plugins(selected_plugins)
Expand All @@ -300,10 +300,10 @@ def reply(
)

# obtain the user query from the last round
user_query = rounds[-1].user_query
query = rounds[-1].post_list[-1].message

if self.config.enable_auto_plugin_selection:
self.plugin_pool = self.select_plugins_for_prompt(user_query)
self.plugin_pool = self.select_plugins_for_prompt(query)

prompt = self.compose_prompt(rounds, self.plugin_pool)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(

if self.config.enable_auto_plugin_selection:
self.plugin_selector = PluginSelector(plugin_registry, self.llm_api)
self.plugin_selector.generate_plugin_embeddings()
logger.info("Plugin embeddings generated")
self.plugin_selector.load_plugin_embeddings()
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

def select_plugins_for_prompt(
Expand Down
63 changes: 56 additions & 7 deletions taskweaver/code_interpreter/code_generator/plugin_selection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict, List

import numpy as np
Expand All @@ -6,6 +7,7 @@

from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry
from taskweaver.utils import generate_md5_hash, write_yaml


class SelectedPluginPool:
Expand Down Expand Up @@ -70,13 +72,60 @@ def __init__(
self.llm_api = llm_api
self.plugin_embedding_dict: Dict[str, List[float]] = {}

def generate_plugin_embeddings(self):
    """Embed "<name>: <description>" for every available plugin.

    All texts are sent to ``llm_api.get_embedding_list`` in a single call and
    the returned vectors are stored in ``plugin_embedding_dict`` by plugin name.
    """
    intro_texts = [plugin.name + ": " + plugin.spec.description for plugin in self.available_plugins]
    vectors = self.llm_api.get_embedding_list(intro_texts)
    # The i-th vector belongs to the i-th plugin.
    for position, plugin in enumerate(self.available_plugins):
        self.plugin_embedding_dict[plugin.name] = vectors[position]
self.exception_message_for_refresh = (
"Please cd to the `script` directory and "
"run `python -m plugin_mgt --refresh` to refresh the plugin embedding."
)

self.meta_file_dir = os.path.join(os.path.dirname(plugin_registry.file_glob), ".meta")
if not os.path.exists(self.meta_file_dir):
os.makedirs(self.meta_file_dir)

def refresh(self):
    """Recompute and persist embeddings for plugins whose metadata is stale.

    A plugin is skipped only when it already has an embedding, that embedding
    was produced by the currently configured embedding model, and its stored
    hash equals the hash of its spec name + description. Every other plugin is
    embedded again (all in one ``get_embedding_list`` call) and its metadata
    is written to ``meta_data.path`` as YAML.
    """

    def is_up_to_date(plugin) -> bool:
        # Checks run cheapest-first and stop at the first failure.
        meta = plugin.meta_data
        if not len(meta.embedding) > 0:
            return False
        if not meta.embedding_model == self.llm_api.embedding_service.config.embedding_model:
            return False
        return meta.md5hash == generate_md5_hash(plugin.spec.name + plugin.spec.description)

    # (index into available_plugins, text to embed) for each stale plugin
    pending = [
        (position, plugin.name + ": " + plugin.spec.description)
        for position, plugin in enumerate(self.available_plugins)
        if not is_up_to_date(plugin)
    ]

    if len(pending) == 0:
        print("All plugins are up-to-date.")
        return

    texts = [text for _, text in pending]
    vectors = self.llm_api.get_embedding_list(texts)

    for order, vector in enumerate(vectors):
        plugin_index = pending[order][0]
        plugin = self.available_plugins[plugin_index]
        meta = plugin.meta_data
        meta.embedding = vector
        meta.embedding_model = self.llm_api.embedding_service.config.embedding_model
        meta.md5hash = generate_md5_hash(plugin.spec.name + plugin.spec.description)
        write_yaml(meta.path, meta.to_dict())

def load_plugin_embeddings(self):
    """Validate the stored embedding of every plugin and cache it by plugin name.

    For each available plugin the stored metadata must (1) contain an
    embedding, (2) have been produced by the embedding model of the current
    session, and (3) carry a hash equal to the hash of the plugin's current
    spec name + description. Valid embeddings are copied into
    ``plugin_embedding_dict``.

    Raises:
        AssertionError: if any of the three checks fails. The exception is
            raised explicitly instead of through ``assert`` statements so the
            validation is not stripped when Python runs with ``-O``; the
            exception type is kept for existing callers.
    """
    for p in self.available_plugins:
        meta = p.meta_data

        # check if the plugin has embedding (missing metadata counts as "no embedding")
        if meta is None or len(meta.embedding) == 0:
            raise AssertionError(
                f"Plugin {p.name} has no embedding. " + self.exception_message_for_refresh,
            )

        # check if the plugin is using the same embedding model as the current session
        session_model = self.llm_api.embedding_service.config.embedding_model
        if meta.embedding_model != session_model:
            raise AssertionError(
                f"Plugin {p.name} is using embedding model {meta.embedding_model}, "
                f"which is different from the one used by current session"
                f" ({session_model}). "
                # trailing space keeps this sentence separated from the refresh hint
                f"Please use the same embedding model or refresh the plugin embedding. "
                + self.exception_message_for_refresh,
            )

        # check if the plugin has been modified
        if meta.md5hash != generate_md5_hash(p.spec.name + p.spec.description):
            raise AssertionError(
                f"Plugin {p.name} has been modified. " + self.exception_message_for_refresh,
            )

        self.plugin_embedding_dict[p.name] = meta.embedding

def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
user_query_embedding = np.array(self.llm_api.get_embedding(user_query))
Expand Down
60 changes: 56 additions & 4 deletions taskweaver/memory/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
from taskweaver.utils import read_yaml, validate_yaml


@dataclass
class PluginMetaData:
    """Metadata stored next to a plugin: its embedding and how it was produced."""

    # plugin name
    name: str
    # embedding vector; an empty list means "not embedded yet"
    embedding: List[float] = field(default_factory=list)
    # identifier of the embedding model that produced `embedding`
    embedding_model: Optional[str] = None
    # path of the metadata file
    path: Optional[str] = None
    # hash used to detect that the plugin changed since it was embedded
    md5hash: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "PluginMetaData":
        """Build the metadata from a mapping; only ``name`` is mandatory.

        A missing *or null* ``embedding`` becomes an empty list so callers
        can always take ``len()`` of it.

        Raises:
            KeyError: if ``name`` is absent.
        """
        return PluginMetaData(
            name=d["name"],
            embedding=d.get("embedding") or [],
            embedding_model=d.get("embedding_model"),
            path=d.get("path"),
            md5hash=d.get("md5hash"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""
        return {
            "name": self.name,
            "embedding": self.embedding,
            "embedding_model": self.embedding_model,
            "path": self.path,
            "md5hash": self.md5hash,
        }


@dataclass
class PluginParameter:
"""PluginParameter is the data structure for plugin parameters (including arguments and return values.)"""
Expand Down Expand Up @@ -41,6 +69,14 @@ def line(cnt: str):

return "\n".join(lines)

def to_dict(self):
    """Return name, type, required and description as a plain dict (in that order)."""
    exported_fields = ("name", "type", "required", "description")
    return {field_name: getattr(self, field_name) for field_name in exported_fields}


@dataclass
class PluginSpec:
Expand All @@ -50,7 +86,6 @@ class PluginSpec:
description: str = ""
args: List[PluginParameter] = field(default_factory=list)
returns: List[PluginParameter] = field(default_factory=list)
embedding: List[float] = field(default_factory=list)

@staticmethod
def from_dict(d: Dict[str, Any]):
Expand All @@ -59,9 +94,16 @@ def from_dict(d: Dict[str, Any]):
description=d["description"],
args=[PluginParameter.from_dict(p) for p in d["parameters"]],
returns=[PluginParameter.from_dict(p) for p in d["returns"]],
embedding=[],
)

def to_dict(self):
    """Serialise the spec; arguments are exported under the key "parameters"."""
    payload = {"name": self.name, "description": self.description}
    # Each parameter object serialises itself.
    payload["parameters"] = [argument.to_dict() for argument in self.args]
    payload["returns"] = [returned.to_dict() for returned in self.returns]
    return payload

def format_prompt(self) -> str:
def normalize_type(t: str) -> str:
if t.lower() == "string":
Expand Down Expand Up @@ -120,14 +162,22 @@ class PluginEntry:
config: Dict[str, Any]
required: bool
enabled: bool = True
meta_data: Optional[PluginMetaData] = None

@staticmethod
def from_yaml_file(path: str) -> Optional["PluginEntry"]:
    """Load a plugin entry from its YAML file together with its metadata.

    The metadata file is looked up at ``<dir of path>/.meta/meta_<file name>``.
    If it exists it is parsed and its ``path`` is set to that location;
    otherwise fresh metadata is created, named after the YAML file without
    its extension and pointing at the same location.

    Args:
        path: path of the plugin YAML file.

    Returns:
        Whatever ``PluginEntry.from_yaml_content`` returns for the parsed
        content and the metadata.
    """
    content = read_yaml(path)
    # No early return here: the metadata must be resolved before the entry is
    # built, otherwise the entry is created without `meta_data`.
    yaml_file_name = os.path.basename(path)
    meta_file_path = os.path.join(os.path.dirname(path), ".meta", f"meta_{yaml_file_name}")
    if os.path.exists(meta_file_path):
        meta_data = PluginMetaData.from_dict(read_yaml(meta_file_path))
        # The stored path may be outdated; always use the actual location.
        meta_data.path = meta_file_path
    else:
        meta_data = PluginMetaData(name=os.path.splitext(yaml_file_name)[0], path=meta_file_path)
    return PluginEntry.from_yaml_content(content, meta_data)

@staticmethod
def from_yaml_content(content: Dict) -> Optional["PluginEntry"]:
def from_yaml_content(content: Dict, meta_data: Optional[PluginMetaData] = None) -> Optional["PluginEntry"]:
do_validate = False
valid_state = False
if do_validate:
Expand All @@ -142,6 +192,7 @@ def from_yaml_content(content: Dict) -> Optional["PluginEntry"]:
required=content.get("required", False),
enabled=content.get("enabled", True),
plugin_only=content.get("plugin_only", False),
meta_data=meta_data,
)
return None

Expand All @@ -156,6 +207,7 @@ def to_dict(self):
"config": self.config,
"required": self.required,
"enabled": self.enabled,
"plugin_only": self.plugin_only,
}

def format_function_calling(self) -> Dict:
Expand Down
15 changes: 15 additions & 0 deletions taskweaver/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import secrets
from datetime import datetime
from hashlib import md5
from typing import Any, Dict


Expand All @@ -24,6 +25,16 @@ def read_yaml(path: str) -> Dict[str, Any]:
raise ValueError(f"Yaml loading failed due to: {e}")


def write_yaml(path: str, content: Dict[str, Any]):
    """Dump ``content`` as YAML into the file at ``path``, keeping key order.

    Raises:
        ValueError: if opening the file or dumping the content fails.
    """
    # Imported lazily, inside the function.
    import yaml

    try:
        with open(path, "w") as stream:
            yaml.safe_dump(content, stream, sort_keys=False)
    except Exception as error:
        raise ValueError(f"Yaml writing failed due to: {error}")


def validate_yaml(content: Any, schema: str) -> bool:
import jsonschema

Expand Down Expand Up @@ -58,3 +69,7 @@ def json_dumps(obj: Any) -> str:

def json_dump(obj: Any, fp: Any):
json.dump(obj, fp, cls=EnhancedJSONEncoder)


def generate_md5_hash(content: str) -> str:
    """Return the hexadecimal MD5 digest of ``content`` (encoded with the default codec)."""
    digest = md5()
    digest.update(content.encode())
    return digest.hexdigest()
Loading