Skip to content

Commit

Permalink
Decouple interactiontool with simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuwei Yan committed Dec 28, 2024
1 parent 3746107 commit 057e403
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 47 deletions.
1 change: 0 additions & 1 deletion websocietysimulator/agent/recommendation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def insert_task(self, task):
if not task:
raise ValueError("The task cannot be None.")
self.task = task.to_dict()
self.interaction_tool.set_task(self.task)

def forward(self) -> List[str]:
"""
Expand Down
1 change: 0 additions & 1 deletion websocietysimulator/agent/simulation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def insert_task(self, task):
if not task:
raise ValueError("The task cannot be None.")
self.task = task.to_dict()
self.interaction_tool.set_task(self.task)

def forward(self) -> Dict[str, Any]:
"""
Expand Down
11 changes: 8 additions & 3 deletions websocietysimulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np

class Simulator:
def __init__(self, data_dir: str, device: str = "auto"):
def __init__(self, data_dir: str = None, device: str = "auto"):
"""
Initialize the Simulator.
Args:
Expand All @@ -21,8 +21,10 @@ def __init__(self, data_dir: str, device: str = "auto"):
"""
logging.info("Start initializing Simulator")
self.data_dir = data_dir

self.interaction_tool = InteractionTool(data_dir)
if data_dir is None:
self.interaction_tool = None
else:
self.interaction_tool = InteractionTool(data_dir)

self.tasks = [] # List to store tasks
self.groundtruth_data = [] # List to store groundtruth data
Expand All @@ -34,6 +36,9 @@ def __init__(self, data_dir: str, device: str = "auto"):
self.evaluation_results = []
logging.info("Simulator initialized")

def set_interaction_tool(self, interaction_tool: InteractionTool):
self.interaction_tool = interaction_tool

def set_task_and_groundtruth(self, task_dir: str, groundtruth_dir: str):
"""
Load tasks from a directory.
Expand Down
50 changes: 8 additions & 42 deletions websocietysimulator/tools/interaction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(self, data_dir: str):
self.user_data = self._load_data('user.json')
self.tip_data = self._load_data('tip.json')
self.checkin_data = self._load_data('checkin.json')
self.task = None

def _load_data(self, filename: str) -> pd.DataFrame:
"""Load a dataset as a Pandas DataFrame."""
Expand All @@ -26,39 +25,16 @@ def _load_data(self, filename: str) -> pd.DataFrame:
data = [json.loads(line) for line in file]
return pd.DataFrame(data)

def set_task(self, task: Dict[str, Any]):
"""
Update the context of the tool based on a task.
Args:
task: Task dictionary with context parameters.
"""
self.task = task

def _ensure_task(self):
"""Ensure that a task has been set before any action."""
if not self.task:
raise RuntimeError("No task has been set. Please set a task before interacting.")

def get_user(self, user_id: Optional[str] = None) -> Optional[Dict]:
"""Fetch user data based on user_id or task."""
self._ensure_task()

user_id = user_id or self.task.get('user_id') if self.task else None
if not user_id:
return None

def get_user(self, user_id: str) -> Optional[Dict]:
"""Fetch user data based on user_id."""
user = self.user_data[self.user_data['user_id'] == user_id]
if user.empty:
return None
user = user.to_dict(orient='records')[0]
return user

def get_business(self, business_id: Optional[str] = None) -> Optional[Dict]:
"""Fetch business data based on business_id or task."""
self._ensure_task() # Ensure task is set
business_id = business_id or self.task.get('business_id') if self.task else None
if not business_id:
return None
def get_business(self, business_id: str) -> Optional[Dict]:
"""Fetch business data based on business_id."""
business = self.business_data[self.business_data['business_id'] == business_id]
return business.to_dict(orient='records')[0] if not business.empty else None

Expand All @@ -69,39 +45,29 @@ def get_reviews(
review_id: Optional[str] = None
) -> List[Dict]:
"""Fetch reviews filtered by various parameters."""
self._ensure_task()

if business_id is None and user_id is None and review_id is None:
return []
reviews = self.review_data

if review_id:
reviews = reviews[reviews['review_id'] == review_id]
else:
business_id = business_id or (self.task.get('business_id') if self.task else None)
user_id = user_id or (self.task.get('user_id') if self.task else None)
if business_id:
reviews = reviews[reviews['business_id'] == business_id]
if user_id:
reviews = reviews[reviews['user_id'] == user_id]
return reviews.to_dict(orient='records')

def get_tips(self, business_id: Optional[str] = None, user_id: Optional[str] = None) -> List[Dict]:
def get_tips(self, business_id: str, user_id: str) -> List[Dict]:
"""Fetch tips with date filter."""
self._ensure_task()
business_id = business_id or (self.task.get('business') if self.task else None)
user_id = user_id or (self.task.get('user') if self.task else None)
tips = self.tip_data
if business_id:
tips = tips[tips['business_id'] == business_id]
if user_id:
tips = tips[tips['user_id'] == user_id]
return tips.to_dict(orient='records')

def get_checkins(self, business_id: Optional[str] = None) -> List[Dict]:
def get_checkins(self, business_id: str) -> List[Dict]:
"""Fetch checkins with date filter."""
self._ensure_task()
business_id = business_id or (self.task.get('business') if self.task else None)
if not business_id:
return []
checkins = self.checkin_data
checkins = checkins[checkins['business_id'] == business_id]
return checkins.to_dict(orient='records')

0 comments on commit 057e403

Please sign in to comment.