Fix data ranges for random data generation (#298)
* Set default ranges when values are unspecified

* Add tests

* Get numeric datatype ranges via a static method

* Add tests for better coverage of NRange
ghanse authored Dec 19, 2024
1 parent f99882d commit a9a7778
Showing 3 changed files with 97 additions and 8 deletions.
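
Before this change, `adjustForColumnDatatype` filled in missing bounds only for decimal columns; every other numeric type was left with `None` bounds, which broke random data generation when `minValue` or `maxValue` was omitted. A minimal sketch of the behavior this commit introduces (assuming `NRange` is importable from the package root; it is defined in dbldatagen/nrange.py):

    from pyspark.sql.types import IntegerType
    from dbldatagen import NRange  # assumed package-root export

    rng = NRange()                        # no minValue / maxValue supplied
    rng.adjustForColumnDatatype(IntegerType())

    # With this commit, unspecified bounds fall back to type-specific defaults
    print(rng.minValue, rng.maxValue)     # expected: 0 65535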
28 changes: 21 additions & 7 deletions dbldatagen/nrange.py
@@ -9,7 +9,7 @@
 import math

 from pyspark.sql.types import LongType, FloatType, IntegerType, DoubleType, ShortType, \
-    ByteType
+    ByteType, DecimalType

 from .datarange import DataRange
@@ -83,13 +83,12 @@ def adjustForColumnDatatype(self, ctype):
         :param ctype: Spark SQL type instance to adjust range for
         :returns: No return value - executes for effect only
         """
-        if ctype.typeName() == 'decimal':
+        numeric_types = [DecimalType, FloatType, DoubleType, ByteType, ShortType, IntegerType, LongType]
+        if type(ctype) in numeric_types:
             if self.minValue is None:
-                self.minValue = 0.0
+                self.minValue = NRange._getNumericDataTypeRange(ctype)[0]
             if self.maxValue is None:
-                self.maxValue = math.pow(10, ctype.precision - ctype.scale) - 1.0
-            if self.step is None:
-                self.step = 1.0
+                self.maxValue = NRange._getNumericDataTypeRange(ctype)[1]

         if type(ctype) is ShortType and self.maxValue is not None:
             assert self.maxValue <= 65536, "`maxValue` must be in range of short"
@@ -145,7 +144,8 @@ def getScale(self):
         # return maximum scale of components
         return max(smin, smax, sstep)

-    def _precision_and_scale(self, x):
+    @staticmethod
+    def _precision_and_scale(x):
         max_digits = 14
         int_part = int(abs(x))
         magnitude = 1 if int_part == 0 else int(math.log10(int_part)) + 1
@@ -158,3 +158,17 @@ def _precision_and_scale(self, x):
             frac_digits /= 10
         scale = int(math.log10(frac_digits))
         return (magnitude + scale, scale)
+
+    @staticmethod
+    def _getNumericDataTypeRange(ctype):
+        value_ranges = {
+            ByteType: (0, (2 ** 4 - 1)),
+            ShortType: (0, (2 ** 8 - 1)),
+            IntegerType: (0, (2 ** 16 - 1)),
+            LongType: (0, (2 ** 32 - 1)),
+            FloatType: (0.0, 3.402e38),
+            DoubleType: (0.0, 1.79769e308)
+        }
+        if type(ctype) is DecimalType:
+            return 0.0, math.pow(10, ctype.precision - ctype.scale) - 1.0
+        return value_ranges.get(type(ctype), None)
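
To make the new defaults concrete, here is what the helper returns for a few types, per the table above. Note that the integral defaults use half of each type's bit width (e.g. 2 ** 16 - 1 for a 32-bit integer), which keeps the default ranges conservative. The direct calls below are for illustration only, since `_getNumericDataTypeRange` is a private static method:

    from pyspark.sql.types import ByteType, IntegerType, DecimalType
    from dbldatagen.nrange import NRange

    print(NRange._getNumericDataTypeRange(ByteType()))         # (0, 15)
    print(NRange._getNumericDataTypeRange(IntegerType()))      # (0, 65535)
    print(NRange._getNumericDataTypeRange(DecimalType(5, 2)))  # (0.0, 999.0), i.e. 10 ** (5 - 2) - 1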
19 changes: 19 additions & 0 deletions tests/test_options.py
@@ -218,6 +218,25 @@ def test_random2(self):
         colSpec3 = ds.getColumnSpec("code3")
         assert colSpec3.random is True

+    def test_random3(self):
+        # will have implied column `id` for ordinal of row
+        ds = (
+            dg.DataGenerator(sparkSession=spark, name="test_dataset1", rows=500, partitions=1, random=True)
+            .withIdOutput()
+            .withColumn("val1", "decimal(5,2)", maxValue=20.0, step=0.01, random=True)
+            .withColumn("val2", "float", maxValue=20.0, random=True)
+            .withColumn("val3", "double", maxValue=20.0, random=True)
+            .withColumn("val4", "byte", maxValue=15, random=True)
+            .withColumn("val5", "short", maxValue=31, random=True)
+            .withColumn("val6", "integer", maxValue=63, random=True)
+            .withColumn("val7", "long", maxValue=127, random=True)
+        )
+
+        df = ds.build()
+        cols = ["val1", "val2", "val3", "val4", "val5", "val6", "val7"]
+        for col in cols:
+            assert df.collect() != df.orderBy(col).collect(), f"Random values were not generated for {col}"
+
     def test_random_multiple_columns(self):
         # will have implied column `id` for ordinal of row
         ds = (
58 changes: 57 additions & 1 deletion tests/test_quick_tests.py
@@ -1,7 +1,11 @@
 from datetime import timedelta, datetime

 import pytest
-from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType, DateType
+from pyspark.sql.types import (
+    StructType, StructField, IntegerType, StringType, FloatType, DateType, DecimalType, DoubleType, ByteType,
+    ShortType, LongType
+)


 import dbldatagen as dg
 from dbldatagen import DataGenerator
@@ -403,6 +407,28 @@ def test_basic_prefix(self):
         rowCount = formattedDF.count()
         assert rowCount == 1000

+    def test_missing_range_values(self):
+        column_types = [FloatType(), DoubleType(), ByteType(), ShortType(), IntegerType(), LongType()]
+        for column_type in column_types:
+            range_no_min = NRange(maxValue=1.0)
+            range_no_max = NRange(minValue=0.0)
+            range_no_min.adjustForColumnDatatype(column_type)
+            assert range_no_min.min == NRange._getNumericDataTypeRange(column_type)[0]
+            assert range_no_min.step == 1
+            range_no_max.adjustForColumnDatatype(column_type)
+            assert range_no_max.max == NRange._getNumericDataTypeRange(column_type)[1]
+            assert range_no_max.step == 1
+
+    def test_range_with_until(self):
+        range_until = NRange(step=2, until=100)
+        range_until.adjustForColumnDatatype(IntegerType())
+        assert range_until.minValue == 0
+        assert range_until.maxValue == 101
+
+    def test_empty_range(self):
+        empty_range = NRange()
+        assert empty_range.isEmpty()
+
     def test_reversed_ranges(self):
         testDataSpec = (dg.DataGenerator(sparkSession=spark, name="ranged_data", rows=100000,
                                          partitions=4)
@@ -695,6 +721,36 @@ def test_strings_from_numeric_string_field4(self):
         rowCount = nullRowsDF.count()
         assert rowCount == 0

+    @pytest.mark.parametrize("columnSpecOptions", [
+        {"dataType": "byte", "minValue": 1, "maxValue": None},
+        {"dataType": "byte", "minValue": None, "maxValue": 10},
+        {"dataType": "short", "minValue": 1, "maxValue": None},
+        {"dataType": "short", "minValue": None, "maxValue": 100},
+        {"dataType": "integer", "minValue": 1, "maxValue": None},
+        {"dataType": "integer", "minValue": None, "maxValue": 100},
+        {"dataType": "long", "minValue": 1, "maxValue": None},
+        {"dataType": "long", "minValue": None, "maxValue": 100},
+        {"dataType": "float", "minValue": 1.0, "maxValue": None},
+        {"dataType": "float", "minValue": None, "maxValue": 100.0},
+        {"dataType": "double", "minValue": 1, "maxValue": None},
+        {"dataType": "double", "minValue": None, "maxValue": 100.0}
+    ])
+    def test_random_generation_without_range_values(self, columnSpecOptions):
+        dataType = columnSpecOptions.get("dataType", None)
+        minValue = columnSpecOptions.get("minValue", None)
+        maxValue = columnSpecOptions.get("maxValue", None)
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="randomGenerationWithoutRangeValues", rows=100,
+                                         partitions=4)
+                        .withIdOutput()
+                        # default column type is string
+                        .withColumn("randCol", colType=dataType, minValue=minValue, maxValue=maxValue, random=True)
+                        )
+
+        df = testDataSpec.build(withTempView=True)
+        sortedDf = df.orderBy("randCol")
+        sortedVals = sortedDf.select("randCol").collect()
+        assert sortedVals != df.select("randCol").collect()
+
     def test_version_info(self):
         # test access to version info without explicit import
         print("Data generator version", dg.__version__)
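Taken together, the tests pin down the user-visible fix: random numeric columns no longer require explicit bounds. A minimal end-to-end sketch, assuming an active SparkSession named `spark`:

    import dbldatagen as dg

    df = (dg.DataGenerator(sparkSession=spark, name="defaults_demo", rows=100, partitions=4)
          .withIdOutput()
          .withColumn("randCol", "integer", random=True)  # neither minValue nor maxValue supplied
          .build())

    # Values should now be drawn from the default integer range (0, 65535)
    df.select("randCol").show(5)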