Skip to content

Commit

Permalink
Xz/plugin embedding dump (microsoft#101)
Browse files Browse the repository at this point in the history
- add plugin embedding persistence for plugin selection
- add plugin manager script 
- fix plugin selection query bug
  • Loading branch information
zhangxu0307 authored Jan 10, 2024
1 parent 5de8f95 commit 945e860
Show file tree
Hide file tree
Showing 12 changed files with 3,319 additions and 20 deletions.
68 changes: 68 additions & 0 deletions scripts/plugin_mgt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import argparse
import os
import sys

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

from injector import Injector

from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.logging import LoggingModule
from taskweaver.memory.plugin import PluginModule

# Default project location: the "project" directory next to this script's
# parent directory.
_default_project_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "project",
)

# Command-line interface of the plugin management script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--project_dir",
    type=str,
    default=_default_project_dir,
    help="The project directory for the TaskWeaver.",
)
parser.add_argument(
    "--refresh",
    action="store_true",
    help="Refresh plugin embeddings.",
)
parser.add_argument(
    "--show",
    action="store_true",
    help="Show plugin information.",
)

# Parsed at import time; the rest of this script reads the result from `args`.
args = parser.parse_args()


class PluginManager:
    """Command-line helper for refreshing and inspecting plugin embeddings.

    Builds a ``PluginSelector`` through the injector, configured from
    ``taskweaver_config.json`` inside the directory given by the module-level
    ``args.project_dir``.
    """

    def __init__(self):
        app_injector = Injector([LoggingModule, PluginModule])
        app_config = AppConfigSource(
            config_file_path=os.path.join(
                args.project_dir,
                "taskweaver_config.json",
            ),
            app_base_path=args.project_dir,
        )
        app_injector.binder.bind(AppConfigSource, to=app_config)
        self.plugin_selector = app_injector.create_object(PluginSelector)

    def refresh(self):
        """Delegate to the selector's ``refresh()`` and print a confirmation."""
        self.plugin_selector.refresh()
        print("Plugin embeddings refreshed.")

    def show(self):
        """Print name, description, embedding info, args and returns per plugin.

        Prints ``No available plugins.`` when the selector has none. A plugin
        without metadata is reported with embedding dimension 0 and model
        ``None`` instead of raising ``AttributeError``.
        """
        plugin_list = self.plugin_selector.available_plugins
        if not plugin_list:
            print("No available plugins.")
            return
        for p in plugin_list:
            # meta_data is optional on a plugin entry; tolerate its absence.
            meta = p.meta_data
            embedding_dim = len(meta.embedding) if meta is not None else 0
            embedding_model = meta.embedding_model if meta is not None else None
            print(f"* Plugin Name: {p.name}")
            print(f"* Plugin Description: {p.spec.description}")
            print(f"* Plugin Embedding dim: {embedding_dim}")
            print(f"* Plugin Embedding model: {embedding_model}")
            print(f"* Plugin Args: {p.spec.args}")
            print(f"* Plugin Returns: {p.spec.returns}")
            print("_________________________________")


if __name__ == "__main__":
    # Build the manager once, then run every action whose flag was passed.
    # Refresh is listed first, so it runs before show.
    plugin_manager = PluginManager()
    requested_actions = (
        (args.refresh, plugin_manager.refresh),
        (args.show, plugin_manager.show),
    )
    for is_requested, action in requested_actions:
        if is_requested:
            action()
12 changes: 6 additions & 6 deletions taskweaver/code_interpreter/code_generator/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(

if self.config.enable_auto_plugin_selection:
self.plugin_selector = PluginSelector(plugin_registry, self.llm_api)
self.plugin_selector.generate_plugin_embeddings()
logger.info("Plugin embeddings generated")
self.plugin_selector.load_plugin_embeddings()
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

def configure_verification(
Expand Down Expand Up @@ -272,10 +272,10 @@ def compose_conversation(

def select_plugins_for_prompt(
self,
user_query: str,
query: str,
) -> List[PluginEntry]:
selected_plugins = self.plugin_selector.plugin_select(
user_query,
query,
self.config.auto_plugin_selection_topk,
)
self.selected_plugin_pool.add_selected_plugins(selected_plugins)
Expand All @@ -300,10 +300,10 @@ def reply(
)

# obtain the user query from the last round
user_query = rounds[-1].user_query
query = rounds[-1].post_list[-1].message

if self.config.enable_auto_plugin_selection:
self.plugin_pool = self.select_plugins_for_prompt(user_query)
self.plugin_pool = self.select_plugins_for_prompt(query)

prompt = self.compose_prompt(rounds, self.plugin_pool)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(

if self.config.enable_auto_plugin_selection:
self.plugin_selector = PluginSelector(plugin_registry, self.llm_api)
self.plugin_selector.generate_plugin_embeddings()
logger.info("Plugin embeddings generated")
self.plugin_selector.load_plugin_embeddings()
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

def select_plugins_for_prompt(
Expand Down
63 changes: 56 additions & 7 deletions taskweaver/code_interpreter/code_generator/plugin_selection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict, List

import numpy as np
Expand All @@ -6,6 +7,7 @@

from taskweaver.llm import LLMApi
from taskweaver.memory.plugin import PluginEntry, PluginRegistry
from taskweaver.utils import generate_md5_hash, write_yaml


class SelectedPluginPool:
Expand Down Expand Up @@ -70,13 +72,60 @@ def __init__(
self.llm_api = llm_api
self.plugin_embedding_dict: Dict[str, List[float]] = {}

def generate_plugin_embeddings(self):
    """Embed each available plugin's "name: description" text and cache it by name."""
    intro_texts: List[str] = [
        plugin.name + ": " + plugin.spec.description for plugin in self.available_plugins
    ]
    embeddings = self.llm_api.get_embedding_list(intro_texts)
    # Embeddings come back in the same order as the texts that were sent.
    for position, plugin in enumerate(self.available_plugins):
        self.plugin_embedding_dict[plugin.name] = embeddings[position]
self.exception_message_for_refresh = (
"Please cd to the `script` directory and "
"run `python -m plugin_mgt --refresh` to refresh the plugin embedding."
)

self.meta_file_dir = os.path.join(os.path.dirname(plugin_registry.file_glob), ".meta")
if not os.path.exists(self.meta_file_dir):
os.makedirs(self.meta_file_dir)

def refresh(self):
    """Re-embed plugins whose stored embedding is missing or out of date.

    A plugin is considered current when it has a non-empty embedding, that
    embedding was produced by the currently configured embedding model, and
    the stored MD5 hash matches the plugin's current name and description.
    All other plugins are embedded in one batch and their metadata is written
    back to each plugin's metadata file.
    """
    # (index into self.available_plugins, text to embed) for each stale plugin
    stale = []
    for position, plugin in enumerate(self.available_plugins):
        meta = plugin.meta_data
        is_current = (
            len(meta.embedding) > 0
            and meta.embedding_model == self.llm_api.embedding_service.config.embedding_model
            and meta.md5hash == generate_md5_hash(plugin.spec.name + plugin.spec.description)
        )
        if not is_current:
            stale.append((position, plugin.name + ": " + plugin.spec.description))

    if not stale:
        print("All plugins are up-to-date.")
        return

    texts = [text for _, text in stale]
    new_embeddings = self.llm_api.get_embedding_list(texts)

    for offset, vector in enumerate(new_embeddings):
        target = self.available_plugins[stale[offset][0]]
        target_meta = target.meta_data
        target_meta.embedding = vector
        target_meta.embedding_model = self.llm_api.embedding_service.config.embedding_model
        target_meta.md5hash = generate_md5_hash(target.spec.name + target.spec.description)
        write_yaml(target_meta.path, target_meta.to_dict())

def load_plugin_embeddings(self):
    """Validate each plugin's stored embedding and cache it by plugin name.

    For every entry in ``self.available_plugins`` the metadata must hold a
    non-empty embedding, must name the embedding model of the current
    session, and must carry an MD5 hash matching the plugin's current name
    and description. Valid embeddings are stored in
    ``self.plugin_embedding_dict``.

    Raises:
        AssertionError: if any check fails. It is raised explicitly instead
            of through ``assert`` so the validation is not stripped when
            Python runs with ``-O``.
    """
    for p in self.available_plugins:
        meta = p.meta_data

        # check if the plugin has embedding
        if len(meta.embedding) == 0:
            raise AssertionError(
                f"Plugin {p.name} has no embedding. " + self.exception_message_for_refresh,
            )

        # check if the plugin is using the same embedding model as the current session
        current_model = self.llm_api.embedding_service.config.embedding_model
        if meta.embedding_model != current_model:
            raise AssertionError(
                f"Plugin {p.name} is using embedding model {meta.embedding_model}, "
                f"which is different from the one used by current session"
                f" ({current_model}). "
                f"Please use the same embedding model or refresh the plugin embedding. "
                + self.exception_message_for_refresh,
            )

        # check if the plugin has been modified
        if meta.md5hash != generate_md5_hash(p.spec.name + p.spec.description):
            raise AssertionError(
                f"Plugin {p.name} has been modified. " + self.exception_message_for_refresh,
            )

        self.plugin_embedding_dict[p.name] = meta.embedding

def plugin_select(self, user_query: str, top_k: int = 5) -> List[PluginEntry]:
user_query_embedding = np.array(self.llm_api.get_embedding(user_query))
Expand Down
60 changes: 56 additions & 4 deletions taskweaver/memory/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
from taskweaver.utils import read_yaml, validate_yaml


@dataclass
class PluginMetaData:
    """Metadata kept for a plugin separately from its spec (embedding info)."""

    # Plugin name.
    name: str
    # Embedding vector; empty when none has been generated yet.
    embedding: List[float] = field(default_factory=list)
    # Name of the embedding model that produced ``embedding``.
    embedding_model: Optional[str] = None
    # Location of the metadata file this object is stored in.
    path: Optional[str] = None
    # MD5 hash recorded when the embedding was generated.
    md5hash: Optional[str] = None

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "PluginMetaData":
        """Build metadata from a mapping; only ``name`` is required.

        A missing or null ``embedding`` becomes an empty list so that callers
        can always take its length.

        Raises:
            KeyError: if ``name`` is absent.
        """
        return PluginMetaData(
            name=d["name"],
            embedding=d.get("embedding") or [],
            embedding_model=d.get("embedding_model"),
            path=d.get("path"),
            md5hash=d.get("md5hash"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order."""
        return {
            "name": self.name,
            "embedding": self.embedding,
            "embedding_model": self.embedding_model,
            "path": self.path,
            "md5hash": self.md5hash,
        }


@dataclass
class PluginParameter:
"""PluginParameter is the data structure for plugin parameters (including arguments and return values.)"""
Expand Down Expand Up @@ -41,6 +69,14 @@ def line(cnt: str):

return "\n".join(lines)

def to_dict(self):
    """Return this parameter as a plain dict: name, type, required, description."""
    exported_fields = ("name", "type", "required", "description")
    return {field_name: getattr(self, field_name) for field_name in exported_fields}


@dataclass
class PluginSpec:
Expand All @@ -50,7 +86,6 @@ class PluginSpec:
description: str = ""
args: List[PluginParameter] = field(default_factory=list)
returns: List[PluginParameter] = field(default_factory=list)
embedding: List[float] = field(default_factory=list)

@staticmethod
def from_dict(d: Dict[str, Any]):
Expand All @@ -59,9 +94,16 @@ def from_dict(d: Dict[str, Any]):
description=d["description"],
args=[PluginParameter.from_dict(p) for p in d["parameters"]],
returns=[PluginParameter.from_dict(p) for p in d["returns"]],
embedding=[],
)

def to_dict(self):
    """Return the spec as a plain dict, serializing args and returns via their ``to_dict``.

    The argument list is emitted under the ``parameters`` key.
    """
    serialized = {"name": self.name, "description": self.description}
    serialized["parameters"] = [argument.to_dict() for argument in self.args]
    serialized["returns"] = [returned.to_dict() for returned in self.returns]
    return serialized

def format_prompt(self) -> str:
def normalize_type(t: str) -> str:
if t.lower() == "string":
Expand Down Expand Up @@ -120,14 +162,22 @@ class PluginEntry:
config: Dict[str, Any]
required: bool
enabled: bool = True
meta_data: Optional[PluginMetaData] = None

@staticmethod
def from_yaml_file(path: str) -> Optional["PluginEntry"]:
    """Load a plugin entry from its YAML file together with its metadata.

    Metadata is looked up at ``<directory of path>/.meta/meta_<yaml file name>``.
    If that file exists it is parsed and its ``path`` is set to the location
    it was read from; otherwise a fresh ``PluginMetaData`` is created, named
    after the YAML file name without extension and pointing at that location.
    """
    content = read_yaml(path)
    yaml_file_name = os.path.basename(path)
    meta_file_path = os.path.join(os.path.dirname(path), ".meta", f"meta_{yaml_file_name}")
    if os.path.exists(meta_file_path):
        meta_data = PluginMetaData.from_dict(read_yaml(meta_file_path))
        # Trust the actual file location over whatever path was persisted.
        meta_data.path = meta_file_path
    else:
        meta_data = PluginMetaData(name=os.path.splitext(yaml_file_name)[0], path=meta_file_path)
    return PluginEntry.from_yaml_content(content, meta_data)

@staticmethod
def from_yaml_content(content: Dict) -> Optional["PluginEntry"]:
def from_yaml_content(content: Dict, meta_data: Optional[PluginMetaData] = None) -> Optional["PluginEntry"]:
do_validate = False
valid_state = False
if do_validate:
Expand All @@ -142,6 +192,7 @@ def from_yaml_content(content: Dict) -> Optional["PluginEntry"]:
required=content.get("required", False),
enabled=content.get("enabled", True),
plugin_only=content.get("plugin_only", False),
meta_data=meta_data,
)
return None

Expand All @@ -156,6 +207,7 @@ def to_dict(self):
"config": self.config,
"required": self.required,
"enabled": self.enabled,
"plugin_only": self.plugin_only,
}

def format_function_calling(self) -> Dict:
Expand Down
15 changes: 15 additions & 0 deletions taskweaver/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import secrets
from datetime import datetime
from hashlib import md5
from typing import Any, Dict


Expand All @@ -24,6 +25,16 @@ def read_yaml(path: str) -> Dict[str, Any]:
raise ValueError(f"Yaml loading failed due to: {e}")


def write_yaml(path: str, content: Dict[str, Any]):
    """Write ``content`` to ``path`` as YAML, replacing any existing file.

    Keys are emitted in insertion order (``sort_keys=False``).

    Raises:
        ValueError: if opening the file or dumping the content fails; the
            underlying exception is attached as the cause.
    """
    import yaml

    try:
        # Explicit encoding so the result does not depend on the platform default.
        with open(path, "w", encoding="utf-8") as file:
            yaml.safe_dump(content, file, sort_keys=False)
    except Exception as e:
        raise ValueError(f"Yaml writing failed due to: {e}") from e


def validate_yaml(content: Any, schema: str) -> bool:
import jsonschema

Expand Down Expand Up @@ -58,3 +69,7 @@ def json_dumps(obj: Any) -> str:

def json_dump(obj: Any, fp: Any):
    """Write ``obj`` as JSON to ``fp``, encoding with ``EnhancedJSONEncoder``."""
    encoder_cls = EnhancedJSONEncoder
    json.dump(obj, fp, cls=encoder_cls)


def generate_md5_hash(content: str) -> str:
    """Return the hexadecimal MD5 digest of ``content``.

    The string is encoded with ``str.encode()``'s default encoding before
    hashing.
    """
    hasher = md5()
    hasher.update(content.encode())
    return hasher.hexdigest()
Loading

0 comments on commit 945e860

Please sign in to comment.