diff --git a/examples/movielens/evaluation.ipynb b/examples/movielens/evaluation.ipynb new file mode 100644 index 0000000..3d2a450 --- /dev/null +++ b/examples/movielens/evaluation.ipynb @@ -0,0 +1,428247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Apply Dirichlet Thompson Sampling on the Movielens Dataset\n", + "The goal is to maximize the number of relevant movie suggestions using `IndependentBandits` and `DirichletThompsonSampling`.\n", + "A rating of 4 and above is considered as positive feedback for the recommendation." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tqdm import tqdm\n", + "from io import BytesIO\n", + "from zipfile import ZipFile\n", + "import matplotlib.pyplot as plt\n", + "from collections import Counter\n", + "from urllib.request import urlopen\n", + "\n", + "from mab_ranking.bandits.rank_bandits import IndependentBandits\n", + "from mab_ranking.bandits.bandits import BetaThompsonSampling, DirichletThompsonSampling\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Helper functions\n", + "\n", + "\n", + "def get_movielens_data(url):\n", + " \"\"\"\n", + " Get the movielens data\n", + " :param url: [str], url name\n", + " \"\"\"\n", + " data = []\n", + " resp = urlopen(url)\n", + " zpfile = ZipFile(BytesIO(resp.read()))\n", + " for line in zpfile.open('ml-100k/u1.base').read().splitlines():\n", + " data.append(\n", + " [int(float(x)) for x in line.decode('utf-8').split('\\t')])\n", + "\n", + " return np.asarray(data)\n", + "\n", + "\n", + "def filter_data(data):\n", + " \"\"\"\n", + " Filters the data by keeping ratings greater than or equal to 4.\n", + " Removes events that contain movies that don't belong in the top 100\n", + " \n", + " :param data: list[int], events data\n", + " \"\"\"\n", + " data_pos_rating = np.asarray([x for x in data 
if x[2] >= 4])\n", + " movies = [x[1] for x in data_pos_rating]\n", + " freq = Counter(movies)\n", + " keep_movies = [x[0] for x in freq.most_common(100)]\n", + "\n", + " filtered_data = []\n", + " for d in data_pos_rating:\n", + " if d[1] in keep_movies:\n", + " filtered_data.append(d)\n", + "\n", + " return np.asarray(filtered_data)\n", + "\n", + "\n", + "def index_data(data):\n", + " \"\"\"\n", + " Indexes users and movies\n", + " \n", + " :param data: list[int], events data\n", + " \"\"\"\n", + " user_indexer = {v: i for i, v in enumerate(set([d[0] for d in data]))}\n", + " movie_indexer = {v: i for i, v in enumerate(set([d[1] for d in data]))}\n", + " indexed_data = []\n", + " for d in data:\n", + " d[0] = user_indexer[d[0]]\n", + " d[1] = movie_indexer[d[1]]\n", + " indexed_data.append(d)\n", + " return np.asarray(indexed_data), user_indexer, movie_indexer\n", + "\n", + "\n", + "def running_mean(x, n):\n", + " \"\"\"\n", + " Calculates the running mean\n", + " \"\"\"\n", + " cumsum = np.cumsum(np.insert(x, 0, 0)) \n", + " return (cumsum[n:] - cumsum[:-n]) / float(n)\n", + "\n", + "\n", + "def plot_ctr(num_iterations, ctr):\n", + " \"\"\"\n", + " Plots the CTR over time\n", + " \n", + " :param num_iterations: [int], number of iterations\n", + " :param ctr: list[float], ctrs over each time step\n", + " \"\"\"\n", + " plt.plot(range(1, num_iterations + 1), ctr)\n", + " plt.xlabel('num_iterations', fontsize=14)\n", + " plt.ylabel('ctr', fontsize=14)\n", + " return plt\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "url = 'http://files.grouplens.org/datasets/movielens/ml-100k.zip'\n", + "\n", + "# get the movielens data\n", + "data = get_movielens_data(url)\n", + "\n", + "# filter the data\n", + "filtered_data = filter_data(data)\n", + "\n", + "# index the data\n", + "indexed_data, user_indexer, movie_indexer = index_data(filtered_data)\n", + "\n", + "# sort the data by timestamp\n", + 
"indexed_data_sorted = sorted(indexed_data, key=lambda x: x[3])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def experiment(indexed_data_sorted, independent_bandits):\n", + " # instantiate the dictionary that would hold all the past actions of each user\n", + " # initial state is 0\n", + " actions_dict = {}\n", + " for i in indexed_data_sorted:\n", + " if i[0] not in actions_dict:\n", + " actions_dict[i[0]] = [0]\n", + "\n", + " sum_binary = 0.0\n", + " ctr_list = []\n", + " k = 100\n", + " i = 1\n", + " for d in tqdm(indexed_data_sorted):\n", + " ground_truth = [d[1]]\n", + "\n", + " selected_items = independent_bandits.choose(context={'previous_action': actions_dict[d[0]][-1]})\n", + " actions_dict[d[0]] += ground_truth\n", + " hit_rate = len(set(ground_truth).intersection(set(selected_items))) / len(set(ground_truth))\n", + "\n", + " feedback_list = [1.0 if _item in ground_truth else 0.0 for _item in selected_items]\n", + " independent_bandits.update(selected_items, feedback_list)\n", + "\n", + " binary_relevancy = 1.0 if hit_rate > 0 else 0.0\n", + " sum_binary += binary_relevancy\n", + " ctr_list.append(sum_binary / i)\n", + " i += 1\n", + "\n", + " ctr_avg = running_mean(ctr_list, k)\n", + "\n", + " return ctr_avg\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Empirical Results for Dirichlet-based Approach" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r 0%| | 0/16635 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_ctr(len(result), result).show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Empirical Results for IBA with BetaTS" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "\r 0%| | 0/16635 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_ctr(len(result_betats), result_betats).show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}