-
Notifications
You must be signed in to change notification settings - Fork 330
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
106 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
|
||
from crawlee._utils.math import weighted_avg | ||
|
||
|
||
@pytest.mark.parametrize( | ||
('values', 'weights', 'expected'), | ||
[ | ||
([20, 40, 50], [2, 3, 5], 41), | ||
([1, 2, 3], [0.5, 0.25, 0.25], 1.75), | ||
([4, 4, 4], [1, 0, 1], 4.0), | ||
([1, 2, 3], [0.33, 0.33, 0.33], 2), | ||
([1, 2, 3], [0.2, -0.3, 0.5], 2.75), | ||
], | ||
ids=['basic', 'fractional_weights', 'zero_weight', 'all_equal_weights', 'negative_weights'], | ||
) | ||
def test__weighted_avg__basic(values: list[float], weights: list[float], expected: float) -> None: | ||
assert weighted_avg(values, weights) == expected | ||
|
||
|
||
def test__weighted_avg__empty() -> None: | ||
values: list[float] = [] | ||
weights: list[float] = [] | ||
with pytest.raises(ValueError, match='Values and weights lists must not be empty'): | ||
weighted_avg(values, weights) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
('values', 'weights'), | ||
[ | ||
([3, 2], [10]), | ||
([2], [1, 5, 7]), | ||
], | ||
) | ||
def test__weighted_avg__unequal_length_lists(values: list[float], weights: list[float]) -> None: | ||
with pytest.raises(ValueError, match='Values and weights must be of equal length'): | ||
weighted_avg(values, weights) | ||
|
||
|
||
def test__weighted_avg__zero_total_weight() -> None: | ||
values: list[float] = [1, 2, 3] | ||
weights: list[float] = [0, 0, 0] | ||
with pytest.raises(ValueError, match='Total weight cannot be zero'): | ||
weighted_avg(values, weights) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from datetime import timedelta | ||
from unittest.mock import AsyncMock | ||
|
||
import pytest | ||
|
||
from crawlee._utils.recurring_task import RecurringTask | ||
|
||
|
||
@pytest.fixture() | ||
def function() -> AsyncMock: | ||
mock_function = AsyncMock() | ||
mock_function.__name__ = 'mocked_function' # To avoid issues with the function name in RecurringTask | ||
return mock_function | ||
|
||
|
||
@pytest.fixture() | ||
def delay() -> timedelta: | ||
return timedelta(milliseconds=30) | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test__recurring_task__init(function: AsyncMock, delay: timedelta) -> None: | ||
rt = RecurringTask(function, delay) | ||
assert rt.func == function | ||
assert rt.delay == delay | ||
assert rt.task is None | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test__recurring_task__start_and_stop(function: AsyncMock, delay: timedelta) -> None: | ||
rt = RecurringTask(function, delay) | ||
|
||
rt.start() | ||
await asyncio.sleep(0) # Yield control to allow the task to start | ||
|
||
assert isinstance(rt.task, asyncio.Task) | ||
assert not rt.task.done() | ||
|
||
await rt.stop() | ||
assert rt.task.done() | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test__recurring_task__execution(function: AsyncMock, delay: timedelta) -> None: | ||
task = RecurringTask(function, delay) | ||
|
||
task.start() | ||
await asyncio.sleep(0.1) # Wait enough for the task to execute a few times | ||
await task.stop() | ||
|
||
assert isinstance(task.func, AsyncMock) # To let MyPy know that the function is a mocked | ||
assert task.func.call_count >= 2 | ||
|
||
await task.stop() |