Feature: cross validate timings #233

Open · wants to merge 8 commits into base: main

Changes from 6 commits
76 changes: 57 additions & 19 deletions rectools/model_selection/cross_validate.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
 import typing as tp
+from contextlib import contextmanager
 
 from rectools.columns import Columns
 from rectools.dataset import Dataset
@@ -24,6 +25,26 @@
 from .splitter import Splitter
 
 
+@contextmanager
+def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
+    """
+    Context manager to compute timing for a code block.
+
+    Parameters
+    ----------
+    label : str
+        Label to store the timing result in the timings dictionary.
+    timings : dict, optional
+        Dictionary to store the timing results. If None, timing is not recorded.
+    """
+    if timings is not None:
+        start_time = time.time()
+        yield
+        timings[label] = round(time.time() - start_time, 2)
Collaborator: We needed the rounding not in the actual code but in the tests below; we are adding rounding just to pass the tests, so it shouldn't affect the actual code in the framework. Let's round to 5 digits here.

Author: Done.

+    else:
+        yield
+
+
 def cross_validate(  # pylint: disable=too-many-locals
     dataset: Dataset,
     splitter: Splitter,
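
For illustration, a minimal, self-contained sketch of how the new context manager behaves; the sleep calls and the dictionary name are placeholders, not part of the PR:

import time
import typing as tp
from contextlib import contextmanager


@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
    # Mirrors the diff above: record elapsed seconds under `label` when a dict is passed.
    if timings is not None:
        start_time = time.time()
        yield
        timings[label] = round(time.time() - start_time, 2)
    else:
        yield


timings: tp.Dict[str, float] = {}
with compute_timing("fit_time", timings):
    time.sleep(0.05)  # placeholder for model.fit(...)
print(timings)  # roughly {'fit_time': 0.05}

with compute_timing("recommend_time", None):
    time.sleep(0.05)  # with timings=None the block still runs, but nothing is recorded
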
@@ -36,6 +57,7 @@ def cross_validate(  # pylint: disable=too-many-locals
     ref_models: tp.Optional[tp.List[str]] = None,
     validate_ref_models: bool = False,
     on_unsupported_targets: ErrorBehaviour = "warn",
+    compute_timings: bool = False,
 ) -> tp.Dict[str, tp.Any]:
     """
     Run cross validation on multiple models with multiple metrics.
@@ -123,28 +145,16 @@
     # ### Train ref models if any
     ref_reco = {}
+    ref_timings = {}
     for model_name in ref_models or []:
         model = models[model_name]
-        model.fit(fold_dataset)
-        ref_reco[model_name] = model.recommend(
-            users=test_users,
-            dataset=fold_dataset,
-            k=k,
-            filter_viewed=filter_viewed,
-            items_to_recommend=items_to_recommend,
-            on_unsupported_targets=on_unsupported_targets,
-        )
-
-    # ### Generate recommendations and calc metrics
-    for model_name, model in models.items():
-        if model_name in ref_reco and not validate_ref_models:
-            continue
-
-        if model_name in ref_reco:
-            reco = ref_reco[model_name]
-        else:
+        model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None
+
+        with compute_timing("fit_time", model_timings):
             model.fit(fold_dataset)
-            reco = model.recommend(
+
+        with compute_timing("recommend_time", model_timings):
+            ref_reco[model_name] = model.recommend(
                 users=test_users,
                 dataset=fold_dataset,
                 k=k,
@@ -153,6 +163,33 @@
                 on_unsupported_targets=on_unsupported_targets,
             )
 
+        ref_timings[model_name] = model_timings or {}
+
+    # ### Generate recommendations and calc metrics
+    for model_name, model in models.items():
+        if model_name in ref_reco and not validate_ref_models:
+            continue
+        if model_name in ref_reco:
+            reco = ref_reco[model_name]
+            model_timing = ref_timings[model_name]
+        else:
+            model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None  # type: ignore
+
+            with compute_timing("fit_time", model_timings):
+                model.fit(fold_dataset)
+
+            with compute_timing("recommend_time", model_timings):
+                reco = model.recommend(
+                    users=test_users,
+                    dataset=fold_dataset,
+                    k=k,
+                    filter_viewed=filter_viewed,
+                    items_to_recommend=items_to_recommend,
+                    on_unsupported_targets=on_unsupported_targets,
+                )
+
+            model_timing = model_timings or {}
+
         metric_values = calc_metrics(
             metrics,
             reco=reco,
@@ -163,6 +200,7 @@
         )
         res = {"model": model_name, "i_split": split_info["i_split"]}
         res.update(metric_values)
+        res.update(model_timing)
         metrics_all.append(res)
 
     result = {"splits": split_infos, "metrics": metrics_all}
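To show the intended effect of the new flag, here is a usage sketch; the toy dataset and model choices are illustrative and assume the public rectools API of the PR branch, while only compute_timings and the fit_time / recommend_time keys come from this PR:

import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.metrics import Precision
from rectools.model_selection import LastNSplitter, cross_validate
from rectools.models import PopularModel, RandomModel

# Toy interactions table; values are illustrative.
interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 1, 2, 2, 2, 3, 3, 3],
        Columns.Item: [10, 20, 30, 10, 20, 30, 10, 20, 30],
        Columns.Weight: 1,
        Columns.Datetime: pd.date_range("2021-01-01", periods=9, freq="D"),
    }
)
dataset = Dataset.construct(interactions)

results = cross_validate(
    dataset=dataset,
    splitter=LastNSplitter(n=1, n_splits=2),
    metrics={"precision@2": Precision(k=2)},
    models={"popular": PopularModel(), "random": RandomModel(random_state=32)},
    k=2,
    filter_viewed=False,
    compute_timings=True,  # new flag from this PR
)

# Each row in results["metrics"] now also carries "fit_time" and "recommend_time".
for row in results["metrics"]:
    print(row["model"], row["i_split"], row["fit_time"], row["recommend_time"])
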
242 changes: 242 additions & 0 deletions tests/model_selection/test_cross_validate.py
@@ -371,5 +371,247 @@ def test_happy_path_with_intersection(
             ],
             "metrics": expected_metrics,
         }
         assert actual == expected
 
+    @pytest.mark.parametrize(
+        "ref_models,validate_ref_models,expected_metrics,compute_timings",
+        (
+            (
+                ["popular"],
Collaborator: Let's keep only ["popular"] and not put ref_models in parametrize. Iterate over validate_ref_models and compute_timings instead; only 4 test cases are needed.

Author: Done.
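A sketch of the suggested parametrization, shown as a standalone function for brevity (in the PR it would remain a method on the test class):

import pytest


@pytest.mark.parametrize("validate_ref_models", (False, True))
@pytest.mark.parametrize("compute_timings", (False, True))
def test_happy_path_with_intersection_timings(validate_ref_models: bool, compute_timings: bool) -> None:
    ref_models = ["popular"]  # fixed value instead of a parametrize axis
    # Stacked parametrize generates 2 x 2 = 4 cases; the cross_validate call
    # and assertions from the test below would go here, using these flags.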

+                False,
+                [
+                    {
+                        "model": "random",
+                        "i_split": 0,
+                        "precision@2": 0.5,
+                        "recall@1": 0.0,
+                        "intersection_popular": 0.5,
+                        "fit_time": 0.0,
+                        "recommend_time": 0.0,
+                    },
+                    {
+                        "model": "random",
+                        "i_split": 1,
+                        "precision@2": 0.375,
+                        "recall@1": 0.5,
+                        "intersection_popular": 0.75,
+                        "fit_time": 0.0,
Collaborator: Let's drop the timings from the expected dicts and keep the threshold comparison; just pop the timings from the actual dict.

Author: Done.
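The pattern the reviewer asks for, sketched against the test below (actual, expected, and time_threshold are the test's own variables; this assumes compute_timings is enabled so the keys exist, and pop(..., None) would make it safe either way):

time_threshold = 0.5

for data in actual["metrics"]:
    # Keep the threshold comparison on the measured values...
    assert data.pop("fit_time") < time_threshold
    assert data.pop("recommend_time") < time_threshold

# ...so the equality check no longer sees the non-deterministic timings,
# and the expected dicts can omit "fit_time" / "recommend_time".
assert actual == expected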

"recommend_time": 0.0,
},
],
True,
),
(
["popular"],
True,
[
{
"model": "popular",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.5,
"intersection_popular": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.0,
"intersection_popular": 0.5,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "popular",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.25,
"intersection_popular": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.5,
"intersection_popular": 0.75,
"fit_time": 0.0,
"recommend_time": 0.0,
},
],
True,
),
(
["random"],
False,
[
{
"model": "popular",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.5,
"intersection_random": 0.5,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "popular",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.25,
"intersection_random": 0.75,
"fit_time": 0.0,
"recommend_time": 0.0,
},
],
True,
),
(
["random"],
True,
[
{
"model": "popular",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.5,
"intersection_random": 0.5,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.0,
"intersection_random": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "popular",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.25,
"intersection_random": 0.75,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.5,
"intersection_random": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
],
True,
),
(["random", "popular"], False, [], True),
(
["random", "popular"],
True,
[
{
"model": "popular",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.5,
"intersection_random": 0.5,
"intersection_popular": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.0,
"intersection_random": 1.0,
"intersection_popular": 0.5,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "popular",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.25,
"intersection_random": 0.75,
"intersection_popular": 1.0,
"fit_time": 0.0,
"recommend_time": 0.0,
},
{
"model": "random",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.5,
"intersection_random": 1.0,
"intersection_popular": 0.75,
"fit_time": 0.0,
"recommend_time": 0.0,
},
],
True,
),
),
)
def test_happy_path_with_intersection_timings(
self,
ref_models: tp.Optional[tp.List[str]],
validate_ref_models: bool,
expected_metrics: tp.List[tp.Dict[str, tp.Any]],
compute_timings: bool,
) -> None:
splitter = LastNSplitter(n=1, n_splits=2, filter_cold_items=False, filter_already_seen=False)

actual = cross_validate(
dataset=self.dataset,
splitter=splitter,
metrics=self.metrics_intersection,
models=self.models,
k=2,
filter_viewed=False,
ref_models=ref_models,
validate_ref_models=validate_ref_models,
compute_timings=compute_timings,
)

time_threshold = 0.5

for data in actual["metrics"]:
print(data["fit_time"])
print(data["recommend_time"])
assert data["fit_time"] < time_threshold
assert data["recommend_time"] < time_threshold

expected = {
"splits": [
{
"i_split": 0,
"test": 2,
"test_items": 2,
"test_users": 2,
"train": 2,
"train_items": 2,
"train_users": 2,
},
{
"i_split": 1,
"test": 4,
"test_items": 3,
"test_users": 4,
"train": 6,
"train_items": 2,
"train_users": 4,
},
],
"metrics": expected_metrics,
}
assert actual == expected