From 057e403eba418f7daa1f810d4b4cd8de821c9185 Mon Sep 17 00:00:00 2001 From: Yuwei Yan Date: Sat, 28 Dec 2024 19:38:40 +0800 Subject: [PATCH] Decouple interactiontool with simulator --- .../agent/recommendation_agent.py | 1 - websocietysimulator/agent/simulation_agent.py | 1 - websocietysimulator/simulator.py | 11 ++-- websocietysimulator/tools/interaction_tool.py | 50 +++---------------- 4 files changed, 16 insertions(+), 47 deletions(-) diff --git a/websocietysimulator/agent/recommendation_agent.py b/websocietysimulator/agent/recommendation_agent.py index bf6230df..60640a06 100644 --- a/websocietysimulator/agent/recommendation_agent.py +++ b/websocietysimulator/agent/recommendation_agent.py @@ -21,7 +21,6 @@ def insert_task(self, task): if not task: raise ValueError("The task cannot be None.") self.task = task.to_dict() - self.interaction_tool.set_task(self.task) def forward(self) -> List[str]: """ diff --git a/websocietysimulator/agent/simulation_agent.py b/websocietysimulator/agent/simulation_agent.py index 7283cfe8..46f26509 100644 --- a/websocietysimulator/agent/simulation_agent.py +++ b/websocietysimulator/agent/simulation_agent.py @@ -18,7 +18,6 @@ def insert_task(self, task): if not task: raise ValueError("The task cannot be None.") self.task = task.to_dict() - self.interaction_tool.set_task(self.task) def forward(self) -> Dict[str, Any]: """ diff --git a/websocietysimulator/simulator.py b/websocietysimulator/simulator.py index 71b93007..91dd77dc 100644 --- a/websocietysimulator/simulator.py +++ b/websocietysimulator/simulator.py @@ -12,7 +12,7 @@ import numpy as np class Simulator: - def __init__(self, data_dir: str, device: str = "auto"): + def __init__(self, data_dir: str = None, device: str = "auto"): """ Initialize the Simulator. Args: @@ -21,8 +21,10 @@ def __init__(self, data_dir: str, device: str = "auto"): """ logging.info("Start initializing Simulator") self.data_dir = data_dir - - self.interaction_tool = InteractionTool(data_dir) + if data_dir is None: + self.interaction_tool = None + else: + self.interaction_tool = InteractionTool(data_dir) self.tasks = [] # List to store tasks self.groundtruth_data = [] # List to store groundtruth data @@ -34,6 +36,9 @@ def __init__(self, data_dir: str, device: str = "auto"): self.evaluation_results = [] logging.info("Simulator initialized") + def set_interaction_tool(self, interaction_tool: InteractionTool): + self.interaction_tool = interaction_tool + def set_task_and_groundtruth(self, task_dir: str, groundtruth_dir: str): """ Load tasks from a directory. diff --git a/websocietysimulator/tools/interaction_tool.py b/websocietysimulator/tools/interaction_tool.py index 45e7b2a5..b3a1fd53 100644 --- a/websocietysimulator/tools/interaction_tool.py +++ b/websocietysimulator/tools/interaction_tool.py @@ -17,7 +17,6 @@ def __init__(self, data_dir: str): self.user_data = self._load_data('user.json') self.tip_data = self._load_data('tip.json') self.checkin_data = self._load_data('checkin.json') - self.task = None def _load_data(self, filename: str) -> pd.DataFrame: """Load a dataset as a Pandas DataFrame.""" @@ -26,39 +25,16 @@ def _load_data(self, filename: str) -> pd.DataFrame: data = [json.loads(line) for line in file] return pd.DataFrame(data) - def set_task(self, task: Dict[str, Any]): - """ - Update the context of the tool based on a task. - Args: - task: Task dictionary with context parameters. - """ - self.task = task - - def _ensure_task(self): - """Ensure that a task has been set before any action.""" - if not self.task: - raise RuntimeError("No task has been set. Please set a task before interacting.") - - def get_user(self, user_id: Optional[str] = None) -> Optional[Dict]: - """Fetch user data based on user_id or task.""" - self._ensure_task() - - user_id = user_id or self.task.get('user_id') if self.task else None - if not user_id: - return None - + def get_user(self, user_id: str) -> Optional[Dict]: + """Fetch user data based on user_id.""" user = self.user_data[self.user_data['user_id'] == user_id] if user.empty: return None user = user.to_dict(orient='records')[0] return user - def get_business(self, business_id: Optional[str] = None) -> Optional[Dict]: - """Fetch business data based on business_id or task.""" - self._ensure_task() # Ensure task is set - business_id = business_id or self.task.get('business_id') if self.task else None - if not business_id: - return None + def get_business(self, business_id: str) -> Optional[Dict]: + """Fetch business data based on business_id.""" business = self.business_data[self.business_data['business_id'] == business_id] return business.to_dict(orient='records')[0] if not business.empty else None @@ -69,26 +45,20 @@ def get_reviews( review_id: Optional[str] = None ) -> List[Dict]: """Fetch reviews filtered by various parameters.""" - self._ensure_task() - + if business_id is None and user_id is None and review_id is None: + return [] reviews = self.review_data - if review_id: reviews = reviews[reviews['review_id'] == review_id] else: - business_id = business_id or (self.task.get('business_id') if self.task else None) - user_id = user_id or (self.task.get('user_id') if self.task else None) if business_id: reviews = reviews[reviews['business_id'] == business_id] if user_id: reviews = reviews[reviews['user_id'] == user_id] return reviews.to_dict(orient='records') - def get_tips(self, business_id: Optional[str] = None, user_id: Optional[str] = None) -> List[Dict]: + def get_tips(self, business_id: str, user_id: str) -> List[Dict]: """Fetch tips with date filter.""" - self._ensure_task() - business_id = business_id or (self.task.get('business') if self.task else None) - user_id = user_id or (self.task.get('user') if self.task else None) tips = self.tip_data if business_id: tips = tips[tips['business_id'] == business_id] @@ -96,12 +66,8 @@ def get_tips(self, business_id: Optional[str] = None, user_id: Optional[str] = N tips = tips[tips['user_id'] == user_id] return tips.to_dict(orient='records') - def get_checkins(self, business_id: Optional[str] = None) -> List[Dict]: + def get_checkins(self, business_id: str) -> List[Dict]: """Fetch checkins with date filter.""" - self._ensure_task() - business_id = business_id or (self.task.get('business') if self.task else None) - if not business_id: - return [] checkins = self.checkin_data checkins = checkins[checkins['business_id'] == business_id] return checkins.to_dict(orient='records')