Skip to content

Commit

Permalink
fix: respect both multiple_of and minimum/maximum constraints
Browse files Browse the repository at this point in the history
Previously, `generate_constrained_number()` could generate invalid
numbers when `multiple_of` is not None and exactly one of `minimum` or
`maximum` is not None, since it would just return `multiple_of` without
respecting the upper or lower bound.

This significantly changes the implementation of
`generate_constrained_number()` in an attempt to handle this case. We
now first check for the presence of `multiple_of`, and if it is None,
then we return early by delegating to the `method` parameter.

Otherwise, we first generate a random number with `method`, and then we
attempt to find the nearest number that is a proper multiple of
`multiple_of`. Most of the newly added complexity of the function is
meant to handle floating-point or Decimal precision issues, and in
worst-case scenarios, it may still not be capable of finding a legal
value.
  • Loading branch information
richardxia committed Feb 26, 2024
1 parent c4e3d91 commit a65482d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 24 deletions.
62 changes: 55 additions & 7 deletions polyfactory/value_generators/constrained_numbers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import decimal
from decimal import Decimal
from math import ceil, floor, isinf
from sys import float_info
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

Expand Down Expand Up @@ -227,16 +229,62 @@ def generate_constrained_number(
:returns: A value of type T.
"""
if minimum is None or maximum is None:
return multiple_of if multiple_of is not None else method(random=random)
if multiple_of is None:
return method(random=random, minimum=minimum, maximum=maximum)
if multiple_of >= minimum:

def passes_all_constraints(value: T) -> bool:
return (
(minimum is None or value >= minimum)
and (maximum is None or value <= maximum)
and (multiple_of is None or passes_pydantic_multiple_validator(value, multiple_of))
)

# If the arguments are Decimals, they might have precision that is greater than the current decimal context. If
# so, recreate them under the current context to ensure they have the appropriate precision. This is important
# because otherwise, x * 1 == x may not always hold, which can cause the algorithm below to fail in unintuitive
# ways.
if isinstance(minimum, Decimal):
minimum = decimal.getcontext().create_decimal(minimum)
if isinstance(maximum, Decimal):
maximum = decimal.getcontext().create_decimal(maximum)
if isinstance(multiple_of, Decimal):
multiple_of = decimal.getcontext().create_decimal(multiple_of)

max_attempts = 10
for _ in range(max_attempts):
# We attempt to generate a random number and find the nearest valid multiple, but a naive approach of rounding
# to the nearest multiple may push the number out of range. To handle edge cases, we find the nearest multiple
# in both the negative and positive directions (floor and ceil), and we pick one that fits within range. We
# should be guaranteed to find a number except when the range (minimum, maximum) is narrow and does not contain
# any multiple of multiple_of.
random_value = method(random=random, minimum=minimum, maximum=maximum)
quotient = random_value / multiple_of
if isinf(quotient):
continue
lower = floor(quotient) * multiple_of
upper = ceil(quotient) * multiple_of

# If both the lower and upper candidates are out of bounds, then there are no valid multiples that fit within
# the specified range.
if minimum is not None and maximum is not None and lower < minimum and upper > maximum:
msg = f"no multiple of {multiple_of} exists between {minimum} and {maximum}"
raise ParameterException(msg)

for candidate in [lower, upper]:
if not passes_all_constraints(candidate):
continue
return candidate

# Try last-ditch attempt at using the multiple_of, 0, or -multiple_of as the value
if passes_all_constraints(multiple_of):
return multiple_of
result = minimum
while not passes_pydantic_multiple_validator(result, multiple_of):
result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of
return result
if passes_all_constraints(-multiple_of):
return -multiple_of
if passes_all_constraints(multiple_of * 0):
return multiple_of * 0

msg = f"could not find solution in {max_attempts} attempts"
raise ValueError(msg)


def handle_constrained_int(
Expand Down
36 changes: 27 additions & 9 deletions tests/constraints/test_decimal_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, cast

import pytest
from hypothesis import given
from hypothesis import assume, given
from hypothesis.strategies import decimals, integers

from pydantic import BaseModel, condecimal
Expand Down Expand Up @@ -239,19 +239,23 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v
decimals(
allow_nan=False,
allow_infinity=False,
min_value=-1000000000,
max_value=1000000000,
min_value=-100000000,
max_value=100000000,
),
decimals(
allow_nan=False,
allow_infinity=False,
min_value=-1000000000,
max_value=1000000000,
min_value=-100000000,
max_value=100000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
if multiple_of != Decimal("0"):
# When multiple_of is too many orders of magnitude smaller than min_value, then floating-point precision issues
# prevent us from constructing a number that can pass passes_pydantic_multiple_validator(). This scenario is
# very unlikely to occur in practice, so we tell Hypothesis to not generate these cases.
assume(abs(min_value / multiple_of) < Decimal("1e8"))
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
Expand All @@ -267,23 +271,37 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v
)


# Note: The magnitudes of the min and max values have been specifically chosen to avoid issues with floating-point
# rounding errors. Despite these tests using Decimal numbers, the function under test will convert them to floats when
# calling `passes_pydantic_multiple_validator()`. Because `passes_pydantic_multiple_validator()` uses the modulus
# operator (%) with a fixed modulus of 1.0, we actually have to care about the absolute rounding error, not the
# relative error. IEEE 754 double-precision floating-point numbers are guaranteed to have at least 15, and up to 17,
# significant decimal digits. `passes_pydantic_multiple_validator()` requires that the remainder modulo 1.0 be within
# 1e-8 of 0.0 or 1.0. Therefore, we can support a maximum value of approximately 10**(15 - 8) = 10**7. We have some
# probabilistic buffer, so we can set a maximum value of 10**8 and expect the tests to pass with reasonable
# confidence.
@given(
decimals(
allow_nan=False,
allow_infinity=False,
min_value=-1000000000,
max_value=1000000000,
min_value=-100000000,
max_value=100000000,
),
decimals(
allow_nan=False,
allow_infinity=False,
min_value=-1000000000,
max_value=1000000000,
min_value=-100000000,
max_value=100000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
if multiple_of != Decimal("0"):
# Despite the note above about choosing a max_value to avoid _absolute_ rounding errors, we also have to worry
# about _relative_ rounding errors between min_value and multiple_of. Once again,
# `passes_pydantic_multiple_validator()` requires that the remainder be no greater than 1e-8, so we tell
# Hypothesis not to generate cases where the min_value and multiple_of have a ratio greater than that.
assume(abs(min_value / multiple_of) < Decimal("1e8"))
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
Expand Down
54 changes: 46 additions & 8 deletions tests/test_number_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from decimal import Decimal, localcontext
from random import Random

import pytest
Expand All @@ -6,24 +7,61 @@
generate_constrained_number,
passes_pydantic_multiple_validator,
)
from polyfactory.value_generators.primitives import create_random_float
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float


@pytest.mark.parametrize(
("maximum", "minimum", "multiple_of"),
((100, 2, 8), (-100, -187, -10), (7.55, 0.13, 0.0123)),
(
(100, 2, 8),
(-100, -187, -10),
(7.55, 0.13, 0.0123),
(None, 10, 3),
(None, -10, 3),
(13, 2, None),
(50, None, 7),
(-50, None, 7),
(None, None, 4),
(900, None, 1000),
),
)
def test_generate_constrained_number(maximum: float, minimum: float, multiple_of: float) -> None:
assert passes_pydantic_multiple_validator(
def test_generate_constrained_number(maximum: float | None, minimum: float | None, multiple_of: float | None) -> None:
value = generate_constrained_number(
random=Random(),
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
value=generate_constrained_number(
method=create_random_float,
)
if maximum is not None:
assert value <= maximum
if minimum is not None:
assert value >= minimum
if multiple_of is not None:
assert passes_pydantic_multiple_validator(multiple_of=multiple_of, value=value)


def test_generate_constrained_number_with_overprecise_decimals() -> None:
minimum = Decimal("1.0005")
maximum = Decimal("2")
multiple_of = Decimal("1.0005")

with localcontext() as ctx:
ctx.prec = 3

value = generate_constrained_number(
random=Random(),
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_float,
),
)
method=create_random_decimal,
)
if maximum is not None:
assert value <= ctx.create_decimal(maximum)
if minimum is not None:
assert value >= ctx.create_decimal(minimum)
if multiple_of is not None:
assert passes_pydantic_multiple_validator(multiple_of=ctx.create_decimal(multiple_of), value=value)


def test_passes_pydantic_multiple_validator_handles_zero_multiplier() -> None:
Expand Down

0 comments on commit a65482d

Please sign in to comment.