Skip to content

Commit

Permalink
Add stock ticker dataset (#303)
Browse files Browse the repository at this point in the history
* Add stock ticker dataset

* Add multi-table sales order dataset

* Fix code formatting

* Fix code formatting

* Add min and max lat-lon options

---------

Co-authored-by: Ronan Stokes <[email protected]>
  • Loading branch information
ghanse and ronanstokes-db authored Dec 19, 2024
1 parent a9a7778 commit b4370c0
Show file tree
Hide file tree
Showing 5 changed files with 896 additions and 51 deletions.
4 changes: 4 additions & 0 deletions dbldatagen/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from .dataset_provider import DatasetProvider, dataset_definition
from .basic_geometries import BasicGeometriesProvider
from .basic_process_historian import BasicProcessHistorianProvider
from .basic_stock_ticker import BasicStockTickerProvider
from .basic_telematics import BasicTelematicsProvider
from .basic_user import BasicUserProvider
from .benchmark_groupby import BenchmarkGroupByProvider
from .multi_table_sales_order_provider import MultiTableSalesOrderProvider
from .multi_table_telephony_provider import MultiTableTelephonyProvider

__all__ = ["dataset_provider",
"basic_geometries",
"basic_process_historian",
"basic_stock_ticker",
"basic_telematics",
"basic_user",
"benchmark_groupby",
"multi_table_sales_order_provider",
"multi_table_telephony_provider"
]
24 changes: 18 additions & 6 deletions dbldatagen/datasets/basic_geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
"""
MIN_LOCATION_ID = 1000000
MAX_LOCATION_ID = 9223372036854775807
DEFAULT_MIN_LAT = -90.0
DEFAULT_MAX_LAT = 90.0
DEFAULT_MIN_LON = -180.0
DEFAULT_MAX_LON = 180.0
COLUMN_COUNT = 2
ALLOWED_OPTIONS = [
"geometryType",
"maxVertices",
"minLatitude",
"maxLatitude",
"minLongitude",
"maxLongitude",
"random"
]

Expand All @@ -45,6 +53,10 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
generateRandom = options.get("random", False)
geometryType = options.get("geometryType", "point")
maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3)
minLatitude = options.get("minLatitude", self.DEFAULT_MIN_LAT)
maxLatitude = options.get("maxLatitude", self.DEFAULT_MAX_LAT)
minLongitude = options.get("minLongitude", self.DEFAULT_MIN_LON)
maxLongitude = options.get("maxLongitude", self.DEFAULT_MAX_LON)

assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
if rows is None or rows < 0:
Expand All @@ -62,9 +74,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
if maxVertices > 1:
w.warn('Ignoring property maxVertices for point geometries')
df_spec = (
df_spec.withColumn("lat", "float", minValue=-90.0, maxValue=90.0,
df_spec.withColumn("lat", "float", minValue=minLatitude, maxValue=maxLatitude,
step=1e-5, random=generateRandom, omit=True)
.withColumn("lon", "float", minValue=-180.0, maxValue=180.0,
.withColumn("lon", "float", minValue=minLongitude, maxValue=maxLongitude,
step=1e-5, random=generateRandom, omit=True)
.withColumn("wkt", "string", expr="concat('POINT(', lon, ' ', lat, ')')")
)
Expand All @@ -75,9 +87,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
j = 0
while j < maxVertices:
df_spec = (
df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0,
df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude,
step=1e-5, random=generateRandom, omit=True)
.withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0,
.withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude,
step=1e-5, random=generateRandom, omit=True)
)
j = j + 1
Expand All @@ -93,9 +105,9 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
j = 0
while j < maxVertices:
df_spec = (
df_spec.withColumn(f"lat_{j}", "float", minValue=-90.0, maxValue=90.0,
df_spec.withColumn(f"lat_{j}", "float", minValue=minLatitude, maxValue=maxLatitude,
step=1e-5, random=generateRandom, omit=True)
.withColumn(f"lon_{j}", "float", minValue=-180.0, maxValue=180.0,
.withColumn(f"lon_{j}", "float", minValue=minLongitude, maxValue=maxLongitude,
step=1e-5, random=generateRandom, omit=True)
)
j = j + 1
Expand Down
103 changes: 103 additions & 0 deletions dbldatagen/datasets/basic_stock_ticker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from random import random

from .dataset_provider import DatasetProvider, dataset_definition


@dataset_definition(name="basic/stock_ticker",
summary="Stock ticker dataset",
autoRegister=True,
supportsStreaming=True)
class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvider):
"""
Basic Stock Ticker Dataset
========================
This is a basic stock ticker dataset with time-series `symbol`, `open`, `close`, `high`, `low`,
`adj_close`, and `volume` values.
It takes the following options when retrieving the table:
- rows : number of rows to generate
- partitions: number of partitions to use
- numSymbols: number of unique stock ticker symbols
- startDate: first date for stock ticker data
- endDate: last date for stock ticker data
As the data specification is a DataGenerator object, you can add further columns to the data set and
add constraints (when the feature is available)
Note that this dataset does not use any features that would prevent it from being used as a source for a
streaming dataframe, and so the flag `supportsStreaming` is set to True.
"""
DEFAULT_NUM_SYMBOLS = 100
DEFAULT_START_DATE = "2024-10-01"
COLUMN_COUNT = 8
ALLOWED_OPTIONS = [
"numSymbols",
"startDate"
]

@DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions=-1, **options):
import dbldatagen as dg

numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS)
startDate = options.get("startDate", self.DEFAULT_START_DATE)

assert tableName is None or tableName == DatasetProvider.DEFAULT_TABLE_NAME, "Invalid table name"
if rows is None or rows < 0:
rows = DatasetProvider.DEFAULT_ROWS
if partitions is None or partitions < 0:
partitions = self.autoComputePartitions(rows, self.COLUMN_COUNT)
if numSymbols <= 0:
raise ValueError("'numSymbols' must be > 0")

df_spec = (
dg.DataGenerator(sparkSession=sparkSession, rows=rows,
partitions=partitions, randomSeedMethod="hash_fieldname")
.withColumn("symbol_id", "long", minValue=676, maxValue=676 + numSymbols - 1)
.withColumn("rand_value", "float", minValue=0.0, maxValue=1.0, step=0.1,
baseColumn="symbol_id", omit=True)
.withColumn("symbol", "string",
expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''),
x -> case when x < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""")
.withColumn("days_from_start_date", "int", expr=f"floor(id / {numSymbols})", omit=True)
.withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)")
.withColumn("start_value", "decimal(11,2)",
values=[1.0 + 199.0 * random() for _ in range(int(numSymbols / 10))], omit=True)
.withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(int(numSymbols / 10))],
baseColumn="symbol_id")
.withColumn("volatility", "float", values=[0.0075 * random() for _ in range(int(numSymbols / 10))],
baseColumn="symbol_id", omit=True)
.withColumn("prev_modifier_sign", "float",
expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end""",
omit=True)
.withColumn("modifier_sign", "float",
expr="case when sin(id % 17) > 0 then -1.0 else 1.0 end",
omit=True)
.withColumn("open_base", "decimal(11,2)",
expr=f"""start_value
+ (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17))
+ (growth_rate * start_value * (days_from_start_date - 1) / 365)""",
omit=True)
.withColumn("close_base", "decimal(11,2)",
expr="""start_value
+ (volatility * start_value * sin(id % 17))
+ (growth_rate * start_value * days_from_start_date / 365)""",
omit=True)
.withColumn("high_base", "decimal(11,2)",
expr="greatest(open_base, close_base) + rand() * volatility * open_base",
omit=True)
.withColumn("low_base", "decimal(11,2)",
expr="least(open_base, close_base) - rand() * volatility * open_base",
omit=True)
.withColumn("open", "decimal(11,2)", expr="greatest(open_base, 0.0)")
.withColumn("close", "decimal(11,2)", expr="greatest(close_base, 0.0)")
.withColumn("high", "decimal(11,2)", expr="greatest(high_base, 0.0)")
.withColumn("low", "decimal(11,2)", expr="greatest(low_base, 0.0)")
.withColumn("dividend", "decimal(4,2)", expr="0.05 * rand_value * close", omit=True)
.withColumn("adj_close", "decimal(11,2)", expr="greatest(close - dividend, 0.0)")
.withColumn("volume", "long", minValue=100000, maxValue=5000000, random=True)
)

return df_spec
Loading

0 comments on commit b4370c0

Please sign in to comment.