From a9a7778f38270f96f77c29ed743c717b235020c6 Mon Sep 17 00:00:00 2001
From: ghanse <163584195+ghanse@users.noreply.github.com>
Date: Thu, 19 Dec 2024 16:13:33 -0500
Subject: [PATCH] Fix data ranges for random data generation (#298)

* Set default ranges when values are unspecified

* Add tests

* Get numeric datatype ranges via a static method

* Add tests for better coverage of NRange
---
 dbldatagen/nrange.py      | 28 +++++++++++++++++++--------
 tests/test_options.py     | 19 +++++++++++++++++++
 tests/test_quick_tests.py | 60 ++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 98 insertions(+), 9 deletions(-)

diff --git a/dbldatagen/nrange.py b/dbldatagen/nrange.py
index 863725e3..62e2da94 100644
--- a/dbldatagen/nrange.py
+++ b/dbldatagen/nrange.py
@@ -9,7 +9,7 @@
 import math
 
 from pyspark.sql.types import LongType, FloatType, IntegerType, DoubleType, ShortType, \
-    ByteType
+    ByteType, DecimalType
 
 from .datarange import DataRange
 
@@ -83,13 +83,12 @@ def adjustForColumnDatatype(self, ctype):
         :param ctype: Spark SQL type instance to adjust range for
         :returns: No return value - executes for effect only
         """
-        if ctype.typeName() == 'decimal':
+        numeric_types = [DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType]
+        if type(ctype) in numeric_types:
             if self.minValue is None:
-                self.minValue = 0.0
+                self.minValue = NRange._getNumericDataTypeRange(ctype)[0]
             if self.maxValue is None:
-                self.maxValue = math.pow(10, ctype.precision - ctype.scale) - 1.0
-            if self.step is None:
-                self.step = 1.0
+                self.maxValue = NRange._getNumericDataTypeRange(ctype)[1]
 
         if type(ctype) is ShortType and self.maxValue is not None:
             assert self.maxValue <= 65536, "`maxValue` must be in range of short"
@@ -145,7 +144,8 @@ def getScale(self):
         # return maximum scale of components
         return max(smin, smax, sstep)
 
-    def _precision_and_scale(self, x):
+    @staticmethod
+    def _precision_and_scale(x):
         max_digits = 14
         int_part = int(abs(x))
         magnitude = 1 if int_part == 0 else int(math.log10(int_part)) + 1
@@ -158,3 +158,17 @@ def _precision_and_scale(self, x):
             frac_digits /= 10
         scale = int(math.log10(frac_digits))
         return (magnitude + scale, scale)
+
+    @staticmethod
+    def _getNumericDataTypeRange(ctype):
+        value_ranges = {
+            ByteType: (0, (2 ** 4 - 1)),
+            ShortType: (0, (2 ** 8 - 1)),
+            IntegerType: (0, (2 ** 16 - 1)),
+            LongType: (0, (2 ** 32 - 1)),
+            FloatType: (0.0, 3.402e38),
+            DoubleType: (0.0, 1.79769e308)
+        }
+        if type(ctype) is DecimalType:
+            return 0.0, math.pow(10, ctype.precision - ctype.scale) - 1.0
+        return value_ranges.get(type(ctype), None)
diff --git a/tests/test_options.py b/tests/test_options.py
index 1fe7a620..78a595f8 100644
--- a/tests/test_options.py
+++ b/tests/test_options.py
@@ -218,6 +218,25 @@ def test_random2(self):
         colSpec3 = ds.getColumnSpec("code3")
         assert colSpec3.random is True
 
+    def test_random3(self):
+        # will have implied column `id` for ordinal of row
+        ds = (
+            dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=500, partitions=1, random=True)
+            .withIdOutput()
+            .withColumn("val1", "decimal(5,2)", maxValue=20.0, step=0.01, random=True)
+            .withColumn("val2", "float", maxValue=20.0, random=True)
+            .withColumn("val3", "double", maxValue=20.0, random=True)
+            .withColumn("val4", "byte", maxValue=15, random=True)
+            .withColumn("val5", "short", maxValue=31, random=True)
+            .withColumn("val6", "integer", maxValue=63, random=True)
+            .withColumn("val7", "long", maxValue=127, random=True)
+        )
+
+        df = ds.build()
+        cols = ["val1", "val2", "val3", "val4", "val5", "val6", "val7"]
+        for col in cols:
+            assert df.collect() != df.orderBy(col).collect(), f"Random values were not generated for {col}"
+
     def test_random_multiple_columns(self):
         # will have implied column `id` for ordinal of row
         ds = (
diff --git a/tests/test_quick_tests.py b/tests/test_quick_tests.py
index 6c075173..de83daa3 100644
--- a/tests/test_quick_tests.py
+++ b/tests/test_quick_tests.py
@@ -1,7 +1,11 @@
 from datetime import timedelta, datetime
 
 import pytest
-from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType
+from pyspark.sql.types import (
+    StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType,
+    ShortType, LongType
+)
+
 import dbldatagen as dg
-from dbldatagen import DataGenerator
+from dbldatagen import DataGenerator, NRange
 
@@ -403,6 +407,28 @@ def test_basic_prefix(self):
         rowCount = formattedDF.count()
         assert rowCount == 1000
 
+    def test_missing_range_values(self):
+        column_types = [FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()]
+        for column_type in column_types:
+            range_no_min = NRange(maxValue=1.0)
+            range_no_max = NRange(minValue=0.0)
+            range_no_min.adjustForColumnDatatype(column_type)
+            assert range_no_min.minValue == NRange._getNumericDataTypeRange(column_type)[0]
+            assert range_no_min.step == 1
+            range_no_max.adjustForColumnDatatype(column_type)
+            assert range_no_max.maxValue == NRange._getNumericDataTypeRange(column_type)[1]
+            assert range_no_max.step == 1
+
+    def test_range_with_until(self):
+        range_until = NRange(step=2, until=100)
+        range_until.adjustForColumnDatatype(IntegerType())
+        assert range_until.minValue == 0
+        assert range_until.maxValue == 101
+
+    def test_empty_range(self):
+        empty_range = NRange()
+        assert empty_range.isEmpty()
+
     def test_reversed_ranges(self):
         testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000,
                                          partitions=4)
@@ -695,6 +721,36 @@ def test_strings_from_numeric_string_field4(self):
         rowCount = nullRowsDF.count()
         assert rowCount == 0
 
+    @pytest.mark.parametrize("columnSpecOptions", [
+        {"dataType": "byte", "minValue": 1, "maxValue": None},
+        {"dataType": "byte", "minValue": None, "maxValue": 10},
+        {"dataType": "short", "minValue": 1, "maxValue": None},
+        {"dataType": "short", "minValue": None, "maxValue": 100},
+        {"dataType": "integer", "minValue": 1, "maxValue": None},
+        {"dataType": "integer", "minValue": None, "maxValue": 100},
+        {"dataType": "long", "minValue": 1, "maxValue": None},
+        {"dataType": "long", "minValue": None, "maxValue": 100},
+        {"dataType": "float", "minValue": 1.0, "maxValue": None},
+        {"dataType": "float", "minValue": None, "maxValue": 100.0},
+        {"dataType": "double", "minValue": 1, "maxValue": None},
+        {"dataType": "double", "minValue": None, "maxValue": 100.0}
+    ])
+    def test_random_generation_without_range_values(self, columnSpecOptions):
+        dataType = columnSpecOptions.get("dataType", None)
+        minValue = columnSpecOptions.get("minValue", None)
+        maxValue = columnSpecOptions.get("maxValue", None)
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100,
+                                         partitions=4)
+                        .withIdOutput()
+                        # default column type is string
+                        .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True)
+                        )
+
+        df = testDataSpec.build(withTempView=True)
+        sortedDf = df.orderBy("randCol")
+        sortedVals = sortedDf.select("randCol").collect()
+        assert sortedVals != df.select("randCol").collect()
+
     def test_version_info(self):
         # test access to version info without explicit import
generator version", dg.__version__)