Skip to content

Commit

Permalink
Refactor: Address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: shamb0 <[email protected]>
  • Loading branch information
shamb0 committed Aug 25, 2024
1 parent 9491b82 commit e947e42
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 221 deletions.
59 changes: 0 additions & 59 deletions tests/common/mod.rs

This file was deleted.

98 changes: 0 additions & 98 deletions tests/common/print_utils.rs

This file was deleted.

84 changes: 25 additions & 59 deletions tests/fixtures/tables/auto_sales.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

use crate::common::{execute_query, fetch_results, print_utils};
use crate::fixtures::*;
use crate::fixtures::{db::Query, S3};
use anyhow::{Context, Result};
use approx::assert_relative_eq;
use datafusion::arrow::record_batch::RecordBatch;
Expand All @@ -40,6 +39,7 @@ use datafusion::parquet::file::properties::WriterProperties;
use std::fs::File;

const YEARS: [i32; 5] = [2020, 2021, 2022, 2023, 2024];

const MANUFACTURERS: [&str; 10] = [
"Toyota",
"Honda",
Expand All @@ -52,6 +52,7 @@ const MANUFACTURERS: [&str; 10] = [
"Hyundai",
"Kia",
];

const MODELS: [&str; 20] = [
"Sedan",
"SUV",
Expand Down Expand Up @@ -91,6 +92,7 @@ pub struct AutoSale {
pub struct AutoSalesSimulator;

impl AutoSalesSimulator {
#[allow(dead_code)]
pub fn generate_data_chunk(chunk_size: usize) -> impl Iterator<Item = AutoSale> {
let mut rng = rand::thread_rng();

Expand Down Expand Up @@ -121,6 +123,7 @@ impl AutoSalesSimulator {
})
}

#[allow(dead_code)]
pub fn save_to_parquet_in_batches(
num_records: usize,
chunk_size: usize,
Expand Down Expand Up @@ -220,6 +223,7 @@ impl AutoSalesSimulator {
pub struct AutoSalesTestRunner;

impl AutoSalesTestRunner {
#[allow(dead_code)]
pub async fn create_partition_and_upload_to_s3(
s3: &S3,
s3_bucket: &str,
Expand Down Expand Up @@ -255,33 +259,35 @@ impl AutoSalesTestRunner {
Ok(())
}

#[allow(dead_code)]
pub async fn teardown_tables(conn: &mut PgConnection) -> Result<()> {
// Drop the partitioned table (this will also drop all its partitions)
let drop_partitioned_table = r#"
DROP TABLE IF EXISTS auto_sales_partitioned CASCADE;
"#;
execute_query(conn, drop_partitioned_table).await?;
drop_partitioned_table.execute_result(conn)?;

// Drop the foreign data wrapper and server
let drop_fdw_and_server = r#"
DROP SERVER IF EXISTS auto_sales_server CASCADE;
"#;
execute_query(conn, drop_fdw_and_server).await?;
drop_fdw_and_server.execute_result(conn)?;

let drop_fdw_and_server = r#"
let drop_parquet_wrapper = r#"
DROP FOREIGN DATA WRAPPER IF EXISTS parquet_wrapper CASCADE;
"#;
execute_query(conn, drop_fdw_and_server).await?;
drop_parquet_wrapper.execute_result(conn)?;

// Drop the user mapping
let drop_user_mapping = r#"
DROP USER MAPPING IF EXISTS FOR public SERVER auto_sales_server;
"#;
execute_query(conn, drop_user_mapping).await?;
drop_user_mapping.execute_result(conn)?;

Ok(())
}

#[allow(dead_code)]
pub async fn setup_tables(conn: &mut PgConnection, s3: &S3, s3_bucket: &str) -> Result<()> {
// First, tear down any existing tables
Self::teardown_tables(conn).await?;
Expand All @@ -291,21 +297,19 @@ impl AutoSalesTestRunner {
for command in s3_fdw_setup.split(';') {
let trimmed_command = command.trim();
if !trimmed_command.is_empty() {
execute_query(conn, trimmed_command).await?;
trimmed_command.execute_result(conn)?;
}
}

execute_query(conn, &Self::create_partitioned_table()).await?;
Self::create_partitioned_table().execute_result(conn)?;

// Create partitions
for year in YEARS {
execute_query(conn, &Self::create_year_partition(year)).await?;
Self::create_year_partition(year).execute_result(conn)?;

for manufacturer in MANUFACTURERS {
execute_query(
conn,
&Self::create_manufacturer_partition(s3_bucket, year, manufacturer),
)
.await?;
Self::create_manufacturer_partition(s3_bucket, year, manufacturer)
.execute_result(conn)?;
}
}

Expand Down Expand Up @@ -382,6 +386,7 @@ impl AutoSalesTestRunner {
impl AutoSalesTestRunner {
/// Asserts that the total sales calculated from `pg_analytics`
/// match the expected results from the DataFrame.
#[allow(dead_code)]
pub async fn assert_total_sales(
conn: &mut PgConnection,
df_sales_data: &DataFrame,
Expand All @@ -401,8 +406,7 @@ impl AutoSalesTestRunner {
);

// Execute the SQL query and fetch results from PostgreSQL.
let total_sales_results: Vec<(i32, String, f64)> =
fetch_results(conn, total_sales_query).await?;
let total_sales_results: Vec<(i32, String, f64)> = total_sales_query.fetch(conn);

// Perform the same calculations on the DataFrame.
let df_result = df_sales_data
Expand Down Expand Up @@ -456,20 +460,6 @@ impl AutoSalesTestRunner {
})
.collect();

// Print the results from both PostgreSQL and DataFrame for comparison.
print_utils::print_results(
vec![
"Year".to_string(),
"Manufacturer".to_string(),
"Total Sales".to_string(),
],
"Pg_Analytics".to_string(),
&total_sales_results,
"DataFrame".to_string(),
&expected_results,
)
.await?;

// Compare the results with a small epsilon for floating-point precision.
for ((pg_year, pg_manufacturer, pg_total), (df_year, df_manufacturer, df_total)) in
total_sales_results.iter().zip(expected_results.iter())
Expand All @@ -484,6 +474,7 @@ impl AutoSalesTestRunner {

/// Asserts that the average price calculated from `pg_analytics`
/// matches the expected results from the DataFrame.
#[allow(dead_code)]
pub async fn assert_avg_price(
conn: &mut PgConnection,
df_sales_data: &DataFrame,
Expand All @@ -503,7 +494,7 @@ impl AutoSalesTestRunner {
);

// Execute the SQL query and fetch results from PostgreSQL.
let avg_price_results: Vec<(String, f64)> = fetch_results(conn, avg_price_query).await?;
let avg_price_results: Vec<(String, f64)> = avg_price_query.fetch(conn);

// Perform the same calculations on the DataFrame.
let df_result = df_sales_data
Expand Down Expand Up @@ -547,16 +538,6 @@ impl AutoSalesTestRunner {
})
.collect();

// Print the results from both PostgreSQL and DataFrame for comparison.
print_utils::print_results(
vec!["Manufacturer".to_string(), "Average Price".to_string()],
"Pg_Analytics".to_string(),
&avg_price_results,
"DataFrame".to_string(),
&expected_results,
)
.await?;

// Compare the results using assert_relative_eq for floating-point precision.
for ((pg_manufacturer, pg_price), (df_manufacturer, df_price)) in
avg_price_results.iter().zip(expected_results.iter())
Expand All @@ -570,6 +551,7 @@ impl AutoSalesTestRunner {

/// Asserts that the monthly sales calculated from `pg_analytics`
/// match the expected results from the DataFrame.
#[allow(dead_code)]
pub async fn assert_monthly_sales(
conn: &mut PgConnection,
df_sales_data: &DataFrame,
Expand All @@ -590,8 +572,7 @@ impl AutoSalesTestRunner {
);

// Execute the SQL query and fetch results from PostgreSQL.
let monthly_sales_results: Vec<(i32, i32, i64, Vec<i64>)> =
fetch_results(conn, monthly_sales_query).await?;
let monthly_sales_results: Vec<(i32, i32, i64, Vec<i64>)> = monthly_sales_query.fetch(conn);

// Perform the same calculations on the DataFrame.
let df_result = df_sales_data
Expand Down Expand Up @@ -659,21 +640,6 @@ impl AutoSalesTestRunner {
})
.collect();

// Print the results from both PostgreSQL and DataFrame for comparison.
print_utils::print_results(
vec![
"Year".to_string(),
"Month".to_string(),
"Sales Count".to_string(),
"Sale IDs (first 5)".to_string(),
],
"Pg_Analytics".to_string(),
&monthly_sales_results,
"DataFrame".to_string(),
&expected_results,
)
.await?;

// Assert that the results from PostgreSQL match the DataFrame results.
assert_eq!(
monthly_sales_results, expected_results,
Expand Down
Loading

0 comments on commit e947e42

Please sign in to comment.