diff --git a/crates/arroyo-connectors/src/lib.rs b/crates/arroyo-connectors/src/lib.rs
index 191d66627..bc4892269 100644
--- a/crates/arroyo-connectors/src/lib.rs
+++ b/crates/arroyo-connectors/src/lib.rs
@@ -1,13 +1,11 @@
 use anyhow::{anyhow, bail, Context};
-use arrow::array::{ArrayRef, RecordBatch};
 use arroyo_operator::connector::ErasedConnector;
 use arroyo_rpc::api_types::connections::{
     ConnectionSchema, ConnectionType, FieldType, SourceField, SourceFieldType, TestSourceMessage,
 };
 use arroyo_rpc::primitive_to_sql;
 use arroyo_rpc::var_str::VarStr;
-use arroyo_types::{string_to_map, SourceError};
-use async_trait::async_trait;
+use arroyo_types::string_to_map;
 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
diff --git a/crates/arroyo-connectors/src/redis/lookup.rs b/crates/arroyo-connectors/src/redis/lookup.rs
index 106fc6108..f14903c3e 100644
--- a/crates/arroyo-connectors/src/redis/lookup.rs
+++ b/crates/arroyo-connectors/src/redis/lookup.rs
@@ -49,7 +49,8 @@ impl LookupConnector for RedisLookup {
         for v in vs {
             match v {
                 Value::Nil => {
-                    self.deserializer.deserialize_slice("null".as_bytes(), SystemTime::now(), None)
+                    self.deserializer
+                        .deserialize_slice("null".as_bytes(), SystemTime::now(), None)
                         .await;
                 }
                 Value::SimpleString(s) => {
diff --git a/crates/arroyo-formats/src/avro/de.rs b/crates/arroyo-formats/src/avro/de.rs
index f371bf70c..4c2bfd115 100644
--- a/crates/arroyo-formats/src/avro/de.rs
+++ b/crates/arroyo-formats/src/avro/de.rs
@@ -214,7 +214,7 @@ mod tests {
     fn deserializer_with_schema(
         format: AvroFormat,
         writer_schema: Option<&str>,
-    ) -> (ArrowDeserializer, Vec<Box<dyn ArrayBuilder>>, ArroyoSchema) {
+    ) -> (ArrowDeserializer, ArroyoSchema) {
         let arrow_schema = if format.into_unstructured_json {
             Schema::new(vec![Field::new("value", DataType::Utf8, false)])
         } else {
@@ -239,13 +239,6 @@ mod tests {
             ArroyoSchema::from_schema_keys(Arc::new(Schema::new(fields)), vec![]).unwrap()
         };

-        let builders: Vec<_> = arroyo_schema
-            .schema
-            .fields
-            .iter()
-            .map(|f| make_builder(f.data_type(), 8))
-            .collect();
-
         let resolver: Arc<dyn SchemaResolver + Sync> = if let Some(schema) = &writer_schema {
             Arc::new(FixedSchemaResolver::new(
                 if format.confluent_schema_registry {
@@ -263,11 +256,10 @@ mod tests {
             ArrowDeserializer::with_schema_resolver(
                 Format::Avro(format),
                 None,
-                arroyo_schema.clone(),
+                Arc::new(arroyo_schema.clone()),
                 BadData::Fail {},
                 resolver,
             ),
-            builders,
             arroyo_schema,
         )
     }
@@ -277,23 +269,15 @@ mod tests {
         writer_schema: Option<&str>,
         message: &[u8],
     ) -> Vec<serde_json::Map<String, serde_json::Value>> {
-        let (mut deserializer, mut builders, arroyo_schema) =
+        let (mut deserializer, arroyo_schema) =
             deserializer_with_schema(format.clone(), writer_schema);

         let errors = deserializer
-            .deserialize_slice(&mut builders, message, SystemTime::now(), None)
+            .deserialize_slice(message, SystemTime::now(), None)
             .await;
         assert_eq!(errors, vec![]);

-        let batch = if format.into_unstructured_json {
-            RecordBatch::try_new(
-                arroyo_schema.schema,
-                builders.into_iter().map(|mut b| b.finish()).collect(),
-            )
-            .unwrap()
-        } else {
-            deserializer.flush_buffer().unwrap().unwrap()
-        };
+        let batch = deserializer.flush_buffer().unwrap().unwrap();

         record_batch_to_vec(&batch, true, arrow_json::writer::TimestampFormat::RFC3339)
             .unwrap()
diff --git a/crates/arroyo-formats/src/de.rs b/crates/arroyo-formats/src/de.rs
index 052662d0d..9eb02a366 100644
--- a/crates/arroyo-formats/src/de.rs
+++ b/crates/arroyo-formats/src/de.rs
@@ -7,7 +7,7 @@ use arrow_array::builder::{
     make_builder, ArrayBuilder, GenericByteBuilder, StringBuilder, TimestampNanosecondBuilder,
 };
 use arrow_array::types::GenericBinaryType;
-use arrow_array::RecordBatch;
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
 use arrow_schema::SchemaRef;
 use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::formats::{
@@ -33,7 +33,6 @@ pub enum FieldValueType<'a> {
 struct ContextBuffer {
     buffer: Vec<Box<dyn ArrayBuilder>>,
     created: Instant,
-    schema: SchemaRef,
 }

 impl ContextBuffer {
@@ -47,7 +46,6 @@ impl ContextBuffer {
         Self {
             buffer,
             created: Instant::now(),
-            schema,
         }
     }

@@ -59,12 +57,8 @@ impl ContextBuffer {
         should_flush(self.size(), self.created)
     }

-    pub fn finish(&mut self) -> RecordBatch {
-        RecordBatch::try_new(
-            self.schema.clone(),
-            self.buffer.iter_mut().map(|a| a.finish()).collect(),
-        )
-        .unwrap()
+    pub fn finish(&mut self) -> Vec<ArrayRef> {
+        self.buffer.iter_mut().map(|a| a.finish()).collect()
     }
 }

@@ -119,19 +113,111 @@ impl<'a> Iterator for FramingIterator<'a> {
     }
 }

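+/// Where not-yet-flushed rows are staged. `Buffer` appends directly to Arrow
+/// builders (raw string/bytes and the unstructured formats), while
+/// `JsonDecoder` feeds bytes to the arrow-json decoder, which assembles the
+/// batch itself and can drop rows that fail to match the schema.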
+enum BufferDecoder {
+    Buffer(ContextBuffer),
+    JsonDecoder {
+        decoder: arrow::json::reader::Decoder,
+        buffered_count: usize,
+        buffered_since: Instant,
+    },
+}
+
+impl BufferDecoder {
+    fn should_flush(&self) -> bool {
+        match self {
+            BufferDecoder::Buffer(b) => b.should_flush(),
+            BufferDecoder::JsonDecoder {
+                buffered_count,
+                buffered_since,
+                ..
+            } => should_flush(*buffered_count, *buffered_since),
+        }
+    }
+
+    fn flush(
+        &mut self,
+        bad_data: &BadData,
+    ) -> Option<Result<(Vec<ArrayRef>, Option<BooleanArray>), SourceError>> {
+        match self {
+            BufferDecoder::Buffer(buffer) => {
+                if buffer.size() > 0 {
+                    Some(Ok((buffer.finish(), None)))
+                } else {
+                    None
+                }
+            }
+            BufferDecoder::JsonDecoder {
+                decoder,
+                buffered_since,
+                buffered_count,
+            } => {
+                *buffered_since = Instant::now();
+                *buffered_count = 0;
+                Some(match bad_data {
+                    BadData::Fail { .. } => decoder
+                        .flush()
+                        .map_err(|e| {
+                            SourceError::bad_data(format!("JSON does not match schema: {:?}", e))
+                        })
+                        .transpose()?
+                        .map(|batch| (batch.columns().to_vec(), None)),
+                    BadData::Drop { .. } => decoder
+                        .flush_with_bad_data()
+                        .map_err(|e| {
+                            SourceError::bad_data(format!(
+                                "Something went wrong decoding JSON: {:?}",
+                                e
+                            ))
+                        })
+                        .transpose()?
+                        .map(|(batch, mask, _)| (batch.columns().to_vec(), Some(mask))),
+                })
+            }
+        }
+    }
+
+    fn decode_json(&mut self, msg: &[u8]) -> Result<(), SourceError> {
+        match self {
+            BufferDecoder::Buffer(_) => {
+                unreachable!("Tried to decode JSON for non-JSON deserializer");
+            }
+            BufferDecoder::JsonDecoder {
+                decoder,
+                buffered_count,
+                ..
+            } => {
+                decoder
+                    .decode(msg)
+                    .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?;
+
+                *buffered_count += 1;
+
+                Ok(())
+            }
+        }
+    }
+
+    fn get_buffer(&mut self) -> &mut ContextBuffer {
+        match self {
+            BufferDecoder::Buffer(buffer) => buffer,
+            BufferDecoder::JsonDecoder { .. } => {
+                panic!("tried to get a raw buffer from a JSON deserializer");
+            }
+        }
+    }
+}
+
 pub struct ArrowDeserializer {
     format: Arc<Format>,
     framing: Option<Arc<Framing>>,
     schema: Arc<ArroyoSchema>,
     bad_data: BadData,
-    json_decoder: Option<(arrow::json::reader::Decoder, TimestampNanosecondBuilder)>,
-    buffered_count: usize,
-    buffered_since: Instant,
     schema_registry: Arc<Mutex<HashMap<u32, Schema>>>,
     proto_pool: DescriptorPool,
     schema_resolver: Arc<dyn SchemaResolver + Sync>,
     additional_fields_builder: Option<HashMap<String, Box<dyn ArrayBuilder>>>,
-    buffer: ContextBuffer,
+    timestamp_builder: Option<TimestampNanosecondBuilder>,
+    buffer_decoder: BufferDecoder,
 }

 impl ArrowDeserializer {
@@ -172,43 +258,42 @@ impl ArrowDeserializer {
             DescriptorPool::global()
         };

+        let buffer_decoder = match format {
+            Format::Json(JsonFormat {
+                unstructured: false,
+                ..
+            })
+            | Format::Avro(AvroFormat {
+                into_unstructured_json: false,
+                ..
+            })
+            | Format::Protobuf(ProtobufFormat {
+                into_unstructured_json: false,
+                ..
+            }) => BufferDecoder::JsonDecoder {
+                decoder: arrow_json::reader::ReaderBuilder::new(Arc::new(
+                    schema.schema_without_timestamp(),
+                ))
+                .with_limit_to_batch_size(false)
+                .with_strict_mode(false)
+                .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. }))
+                .build_decoder()
+                .unwrap(),
+                buffered_count: 0,
+                buffered_since: Instant::now(),
+            },
+            _ => BufferDecoder::Buffer(ContextBuffer::new(Arc::new(
+                schema.schema_without_timestamp(),
+            ))),
+        };
+
         Self {
-            json_decoder: matches!(
-                format,
-                Format::Json(..)
-                    | Format::Avro(AvroFormat {
-                        into_unstructured_json: false,
-                        ..
-                    })
-                    | Format::Protobuf(ProtobufFormat {
-                        into_unstructured_json: false,
-                        ..
-                    })
-            )
-            .then(|| {
-                // exclude the timestamp field
-                (
-                    arrow_json::reader::ReaderBuilder::new(Arc::new(
-                        schema.schema_without_timestamp(),
-                    ))
-                    .with_limit_to_batch_size(false)
-                    .with_strict_mode(false)
-                    .with_allow_bad_data(matches!(bad_data, BadData::Drop { .. }))
-                    .build_decoder()
-                    .unwrap(),
-                    TimestampNanosecondBuilder::new(),
-                )
-            }),
             format: Arc::new(format),
             framing: framing.map(Arc::new),
-            buffer: ContextBuffer::new(schema.schema.clone()),
+            buffer_decoder,
+            timestamp_builder: Some(TimestampNanosecondBuilder::with_capacity(128)),
             schema,
             schema_registry: Arc::new(Mutex::new(HashMap::new())),
             bad_data,
             schema_resolver,
             proto_pool,
-            buffered_count: 0,
-            buffered_since: Instant::now(),
             additional_fields_builder: None,
         }
     }
@@ -219,108 +304,120 @@ impl ArrowDeserializer {
         timestamp: SystemTime,
         additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>,
     ) -> Vec<SourceError> {
-        match &*self.format {
-            Format::Avro(_) => self.deserialize_slice_avro(msg, timestamp).await,
-            _ => FramingIterator::new(self.framing.clone(), msg)
-                .map(|t| self.deserialize_single(t, timestamp, additional_fields))
-                .filter_map(|t| t.err())
-                .collect(),
+        self.deserialize_slice_int(msg, Some(timestamp), additional_fields)
+            .await
+    }
+
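+    /// Shared decode path. Decodes every framed record in `msg`, then appends
+    /// the timestamp and any additional fields once per successfully decoded
+    /// record so those columns stay length-aligned with the decoded rows.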
+    async fn deserialize_slice_int(
+        &mut self,
+        msg: &[u8],
+        timestamp: Option<SystemTime>,
+        additional_fields: Option<&HashMap<&String, FieldValueType<'_>>>,
+    ) -> Vec<SourceError> {
+        let (count, errors) = match &*self.format {
+            Format::Avro(_) => self.deserialize_slice_avro(msg).await,
+            _ => {
+                let mut count = 0;
+                let errors = FramingIterator::new(self.framing.clone(), msg)
+                    .map(|t| self.deserialize_single(t))
+                    .filter_map(|t| {
+                        if t.is_ok() {
+                            count += 1;
+                        }
+                        t.err()
+                    })
+                    .collect();
+                (count, errors)
+            }
+        };
+
+        if let Some(timestamp) = timestamp {
+            let b = self
+                .timestamp_builder
+                .as_mut()
+                .expect("tried to serialize timestamp to a schema without a timestamp column");
+
+            for _ in 0..count {
+                b.append_value(to_nanos(timestamp) as i64);
+            }
         }
+
+        if let Some(additional_fields) = additional_fields {
+            if self.additional_fields_builder.is_none() {
+                let mut builders = HashMap::new();
+                for (key, value) in additional_fields.iter() {
+                    let builder: Box<dyn ArrayBuilder> = match value {
+                        FieldValueType::Int32(_) => Box::new(Int32Builder::new()),
+                        FieldValueType::Int64(_) => Box::new(Int64Builder::new()),
+                        FieldValueType::String(_) => Box::new(StringBuilder::new()),
+                    };
+                    builders.insert(key.to_string(), builder);
+                }
+                self.additional_fields_builder = Some(builders);
+            }
+
+            let builders = self.additional_fields_builder.as_mut().unwrap();
+
+            for (k, v) in additional_fields {
+                add_additional_fields(builders, k, v, count);
+            }
+        }
+
+        errors
     }

     pub fn should_flush(&self) -> bool {
-        self.buffer.should_flush() || should_flush(self.buffered_count, self.buffered_since)
+        self.buffer_decoder.should_flush()
     }

     pub fn flush_buffer(&mut self) -> Option<Result<RecordBatch, SourceError>> {
-        if self.buffer.size() > 0 {
-            return Some(Ok(self.buffer.finish()));
-        }
+        let (mut arrays, error_mask) = match self.buffer_decoder.flush(&self.bad_data)? {
+            Ok((a, b)) => (a, b),
+            Err(e) => return Some(Err(e)),
+        };

-        let (decoder, timestamp) = self.json_decoder.as_mut()?;
-        self.buffered_since = Instant::now();
-        self.buffered_count = 0;
-        match self.bad_data {
-            BadData::Fail { .. } => Some(
-                decoder
-                    .flush()
-                    .map_err(|e| {
-                        SourceError::bad_data(format!("JSON does not match schema: {:?}", e))
-                    })
-                    .transpose()?
-                    .map(|batch| {
-                        let mut columns = batch.columns().to_vec();
-                        columns.insert(self.schema.timestamp_index, Arc::new(timestamp.finish()));
-                        flush_additional_fields_builders(
-                            &mut self.additional_fields_builder,
-                            &self.schema,
-                            &mut columns,
-                        );
-                        RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap()
-                    }),
-            ),
-            BadData::Drop { .. } => Some(
-                decoder
-                    .flush_with_bad_data()
-                    .map_err(|e| {
-                        SourceError::bad_data(format!(
-                            "Something went wrong decoding JSON: {:?}",
-                            e
-                        ))
-                    })
-                    .transpose()?
-                    .map(|(batch, mask, _)| {
-                        let mut columns = batch.columns().to_vec();
-                        let timestamp =
-                            kernels::filter::filter(&timestamp.finish(), &mask).unwrap();
-
-                        columns.insert(self.schema.timestamp_index, Arc::new(timestamp));
-                        flush_additional_fields_builders(
-                            &mut self.additional_fields_builder,
-                            &self.schema,
-                            &mut columns,
-                        );
-                        RecordBatch::try_new(self.schema.schema.clone(), columns).unwrap()
-                    }),
-            ),
+        if let Some(additional_fields) = &mut self.additional_fields_builder {
+            for (name, builder) in additional_fields {
+                let (idx, _) = self
+                    .schema
+                    .schema
+                    .column_with_name(&name)
+                    .unwrap_or_else(|| panic!("Field '{}' not found in schema", name));
+
+                let mut array = builder.finish();
+                if let Some(error_mask) = &error_mask {
+                    array = kernels::filter::filter(&array, error_mask).unwrap();
+                }
+
+                arrays[idx] = array;
+            }
+        };
+
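+        // If rows were dropped as bad data, run the timestamp column through
+        // the same mask so it stays aligned with the surviving rows.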
+        if let Some(timestamp) = &mut self.timestamp_builder {
+            let array = if let Some(error_mask) = &error_mask {
+                kernels::filter::filter(&timestamp.finish(), error_mask).unwrap()
+            } else {
+                Arc::new(timestamp.finish())
+            };
+
+            arrays.insert(self.schema.timestamp_index, array);
         }
+
+        Some(Ok(
+            RecordBatch::try_new(self.schema.schema.clone(), arrays).unwrap()
+        ))
     }

-    fn deserialize_single(
-        &mut self,
-        msg: &[u8],
-        timestamp: SystemTime,
-        additional_fields: Option<&HashMap<&String, FieldValueType>>,
-    ) -> Result<(), SourceError> {
+    fn deserialize_single(&mut self, msg: &[u8]) -> Result<(), SourceError> {
         match &*self.format {
             Format::RawString(_)
             | Format::Json(JsonFormat {
                 unstructured: true,
                 ..
             }) => {
                 self.deserialize_raw_string(msg);
-                add_timestamp(
-                    &mut self.buffer.buffer,
-                    self.schema.timestamp_index,
-                    timestamp,
-                );
-                if let Some(fields) = additional_fields {
-                    for (k, v) in fields.iter() {
-                        add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v);
-                    }
-                }
             }
             Format::RawBytes(_) => {
                 self.deserialize_raw_bytes(msg);
-                add_timestamp(
-                    &mut self.buffer.buffer,
-                    self.schema.timestamp_index,
-                    timestamp,
-                );
-                if let Some(fields) = additional_fields {
-                    for (k, v) in fields.iter() {
-                        add_additional_fields(&mut self.buffer.buffer, &self.schema, k, v);
-                    }
-                }
             }
             Format::Json(json) => {
                 let msg = if json.confluent_schema_registry {
@@ -329,62 +426,17 @@ impl ArrowDeserializer {
                     msg
                 };

-                let Some((decoder, timestamp_builder)) = &mut self.json_decoder else {
-                    panic!("json decoder not initialized");
-                };
-
-                if self.additional_fields_builder.is_none() {
-                    if let Some(fields) = additional_fields.as_ref() {
-                        let mut builders = HashMap::new();
-                        for (key, value) in fields.iter() {
-                            let builder: Box<dyn ArrayBuilder> = match value {
-                                FieldValueType::Int32(_) => Box::new(Int32Builder::new()),
-                                FieldValueType::Int64(_) => Box::new(Int64Builder::new()),
-                                FieldValueType::String(_) => Box::new(StringBuilder::new()),
-                            };
-                            builders.insert(key, builder);
-                        }
-                        self.additional_fields_builder = Some(
-                            builders
-                                .into_iter()
-                                .map(|(k, v)| ((*k).clone(), v))
-                                .collect(),
-                        );
-                    }
-                }
-
-                decoder
-                    .decode(msg)
-                    .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?;
-                timestamp_builder.append_value(to_nanos(timestamp) as i64);
-
-                add_additional_fields_using_builder(
-                    additional_fields,
-                    &mut self.additional_fields_builder,
-                );
-                self.buffered_count += 1;
+                self.buffer_decoder.decode_json(msg)?;
             }
             Format::Protobuf(proto) => {
                 let json = proto::de::deserialize_proto(&mut self.proto_pool, proto, msg)?;
                 if proto.into_unstructured_json {
-                    self.decode_into_json(json, timestamp);
+                    self.decode_into_json(json);
                 } else {
-                    let Some((decoder, timestamp_builder)) = &mut self.json_decoder else {
-                        panic!("json decoder not initialized");
-                    };
-
-                    decoder
-                        .decode(json.to_string().as_bytes())
+                    self.buffer_decoder
+                        .decode_json(json.to_string().as_bytes())
                         .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?;
-                    timestamp_builder.append_value(to_nanos(timestamp) as i64);
-
-                    add_additional_fields_using_builder(
-                        additional_fields,
-                        &mut self.additional_fields_builder,
-                    );
-
-                    self.buffered_count += 1;
                 }
             }
             Format::Avro(_) => unreachable!("this should not be called for avro"),
@@ -394,31 +446,21 @@ impl ArrowDeserializer {
         Ok(())
     }

-    fn decode_into_json(&mut self, value: Value, timestamp: SystemTime) {
+    fn decode_into_json(&mut self, value: Value) {
         let (idx, _) = self
             .schema
             .schema
             .column_with_name("value")
             .expect("no 'value' column for unstructured avro");
-        let array = self.buffer.buffer[idx]
+        let array = self.buffer_decoder.get_buffer().buffer[idx]
             .as_any_mut()
             .downcast_mut::<StringBuilder>()
             .expect("'value' column has incorrect type");
         array.append_value(value.to_string());
-        add_timestamp(
-            &mut self.buffer.buffer,
-            self.schema.timestamp_index,
-            timestamp,
-        );
-        self.buffered_count += 1;
     }

-    pub async fn deserialize_slice_avro<'a>(
-        &mut self,
-        msg: &'a [u8],
-        timestamp: SystemTime,
-    ) -> Vec<SourceError> {
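+    /// Returns how many records were decoded along with any per-record
+    /// errors; the caller appends one timestamp per decoded record.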
+    async fn deserialize_slice_avro(&mut self, msg: &[u8]) -> (usize, Vec<SourceError>) {
         let Format::Avro(format) = &*self.format else {
             unreachable!("not avro");
         };
@@ -433,13 +475,14 @@ impl ArrowDeserializer {
         {
             Ok(messages) => messages,
             Err(e) => {
-                return vec![e];
+                return (0, vec![e]);
             }
         };

         let into_json = format.into_unstructured_json;

-        messages
+        let mut count = 0;
+        let errors = messages
             .into_iter()
             .map(|record| {
                 let value = record.map_err(|e| {
@@ -447,27 +490,25 @@ impl ArrowDeserializer {
                 })?;

                 if into_json {
-                    self.decode_into_json(de::avro_to_json(value), timestamp);
+                    self.decode_into_json(de::avro_to_json(value));
                 } else {
                     // for now round-trip through json in order to handle unsupported avro features
                     // as that allows us to rely on raw json deserialization
                     let json = de::avro_to_json(value).to_string();

-                    let Some((decoder, timestamp_builder)) = &mut self.json_decoder else {
-                        panic!("json decoder not initialized");
-                    };
-
-                    decoder
-                        .decode(json.as_bytes())
+                    self.buffer_decoder
+                        .decode_json(json.as_bytes())
                         .map_err(|e| SourceError::bad_data(format!("invalid JSON: {:?}", e)))?;
-                    self.buffered_count += 1;
-                    timestamp_builder.append_value(to_nanos(timestamp) as i64);
                 }

+                count += 1;
+
                 Ok(())
             })
             .filter_map(|r: Result<(), SourceError>| r.err())
-            .collect()
+            .collect();
+
+        (count, errors)
     }

     fn deserialize_raw_string(&mut self, msg: &[u8]) {
@@ -476,7 +517,7 @@ impl ArrowDeserializer {
             .schema
             .column_with_name("value")
             .expect("no 'value' column for RawString format");
-        self.buffer.buffer[col]
+        self.buffer_decoder.get_buffer().buffer[col]
             .as_any_mut()
             .downcast_mut::<StringBuilder>()
             .expect("'value' column has incorrect type")
@@ -489,7 +530,7 @@ impl ArrowDeserializer {
             .schema
             .column_with_name("value")
             .expect("no 'value' column for RawBytes format");
-        self.buffer.buffer[col]
+        self.buffer_decoder.get_buffer().buffer[col]
             .as_any_mut()
             .downcast_mut::<GenericByteBuilder<GenericBinaryType<i32>>>()
             .expect("'value' column has incorrect type")
@@ -501,111 +542,42 @@ impl ArrowDeserializer {
     }
 }

-pub(crate) fn add_timestamp(
-    builder: &mut [Box<dyn ArrayBuilder>],
-    idx: usize,
-    timestamp: SystemTime,
-) {
-    builder[idx]
-        .as_any_mut()
-        .downcast_mut::<TimestampNanosecondBuilder>()
-        .expect("_timestamp column has incorrect type")
-        .append_value(to_nanos(timestamp) as i64);
-}
-
-pub(crate) fn add_additional_fields(
-    builder: &mut [Box<dyn ArrayBuilder>],
-    schema: &ArroyoSchema,
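+// Appends `value` to its builder `count` times, once for each record decoded
+// from the message the fields were attached to.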
.expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::Int64(i) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(*i); - } - FieldValueType::String(s) => { - builder - .as_any_mut() - .downcast_mut::() - .expect("additional field has incorrect type") - .append_value(s); - } - } + for _ in 0..count { + b.append_value(*i); } } - } -} + FieldValueType::String(s) => { + let b = builder + .downcast_mut::() + .expect("additional field has incorrect type"); -pub(crate) fn flush_additional_fields_builders( - additional_fields_builder: &mut Option>>, - schema: &ArroyoSchema, - columns: &mut [Arc], -) { - if let Some(additional_fields) = additional_fields_builder.take() { - for (field_name, mut builder) in additional_fields { - if let Some((idx, _)) = schema.schema.column_with_name(&field_name) { - let expected_type = schema.schema.fields[idx].data_type(); - let built_column = builder.as_mut().finish(); - let actual_type = built_column.data_type(); - if expected_type != actual_type { - panic!( - "Type mismatch for column '{}': expected {:?}, got {:?}", - field_name, expected_type, actual_type - ); - } - columns[idx] = Arc::new(built_column); - } else { - panic!("Field '{}' not found in schema", field_name); + for _ in 0..count { + b.append_value(*s); } } } @@ -615,10 +587,8 @@ pub(crate) fn flush_additional_fields_builders( mod tests { use crate::de::{ArrowDeserializer, FieldValueType, FramingIterator}; use arrow::datatypes::Int32Type; - use arrow_array::builder::{make_builder, ArrayBuilder}; use arrow_array::cast::AsArray; use arrow_array::types::{GenericBinaryType, Int64Type, TimestampNanosecondType}; - use arrow_array::RecordBatch; use arrow_schema::{Schema, TimeUnit}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{ @@ -800,12 +770,6 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( @@ -821,8 +785,7 @@ mod tests { .await; assert!(result.is_empty()); - let arrays: Vec<_> = arrays.into_iter().map(|mut a| a.finish()).collect(); - let batch = RecordBatch::try_new(schema, arrays).unwrap(); + let batch = deserializer.flush_buffer().unwrap().unwrap(); assert_eq!(batch.num_rows(), 1); assert_eq!( @@ -852,12 +815,6 @@ mod tests { ), ])); - let mut arrays: Vec<_> = schema - .fields - .iter() - .map(|f| make_builder(f.data_type(), 16)) - .collect(); - let arroyo_schema = Arc::new(ArroyoSchema::from_schema_unkeyed(schema.clone()).unwrap()); let mut deserializer = ArrowDeserializer::new( diff --git a/crates/arroyo-planner/src/extension/lookup.rs b/crates/arroyo-planner/src/extension/lookup.rs index 2d72a05f3..210d19d38 100644 --- a/crates/arroyo-planner/src/extension/lookup.rs +++ b/crates/arroyo-planner/src/extension/lookup.rs @@ -1,6 +1,7 @@ use crate::builder::{NamedNode, Planner}; use crate::extension::{ArroyoExtension, NodeWithIncomingEdges}; use crate::multifield_partial_ord; +use crate::schemas::add_timestamp_field_arrow; use crate::tables::ConnectorTable; use arroyo_datastream::logical::{LogicalEdge, LogicalEdgeType, LogicalNode, OperatorName}; use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; @@ -13,7 +14,6 @@ use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use prost::Message; use std::fmt::Formatter; use std::sync::Arc; -use 
@@ -88,8 +88,9 @@ impl ArroyoExtension for LookupJoin {
         input_schemas: Vec<ArroyoSchemaRef>,
     ) -> datafusion::common::Result<NodeWithIncomingEdges> {
         let schema = ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema.as_ref().into()))?;
-        let lookup_schema = ArroyoSchema::from_schema_unkeyed(
-            add_timestamp_field_arrow(self.connector.physical_schema()))?;
+        let lookup_schema = ArroyoSchema::from_schema_unkeyed(add_timestamp_field_arrow(
+            self.connector.physical_schema(),
+        ))?;
         let join_config = LookupJoinOperator {
             input_schema: Some(schema.into()),
             lookup_schema: Some(lookup_schema.into()),
diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs
index 6b693d06c..49e150bfa 100644
--- a/crates/arroyo-rpc/src/lib.rs
+++ b/crates/arroyo-rpc/src/lib.rs
@@ -37,7 +37,7 @@ pub mod grpc {
     pub mod api {
         #![allow(clippy::derive_partial_eq_without_eq, deprecated)]
         tonic::include_proto!("api");
-
+
         impl From<JoinType> for arroyo_types::JoinType {
             fn from(value: JoinType) -> Self {
                 match value {
diff --git a/crates/arroyo-worker/src/arrow/lookup_join.rs b/crates/arroyo-worker/src/arrow/lookup_join.rs
index 05e376c40..1403e0eaf 100644
--- a/crates/arroyo-worker/src/arrow/lookup_join.rs
+++ b/crates/arroyo-worker/src/arrow/lookup_join.rs
@@ -2,8 +2,8 @@
 use std::collections::HashMap;
 use std::sync::Arc;

 use arrow::row::{OwnedRow, RowConverter, SortField};
 use arrow_array::RecordBatch;
-use arroyo_connectors::{connectors};
+use arroyo_connectors::connectors;
 use arroyo_operator::connector::LookupConnector;
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::operator::{
@@ -143,20 +146,29 @@ impl OperatorConstructor for LookupJoinConstructor {
         let op = config.connector.unwrap();
         let operator_config = serde_json::from_str(&op.config)?;

-        let result_row_converter = RowConverter::new(lookup_schema.schema.fields.iter().map(|f|
-            SortField::new(f.data_type().clone())).collect())?;
-
+        let result_row_converter = RowConverter::new(
+            lookup_schema
+                .schema
+                .fields
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
         let connector = connectors()
             .get(op.connector.as_str())
             .unwrap_or_else(|| panic!("No connector with name '{}'", op.connector))
             .make_lookup(operator_config, Arc::new(lookup_schema))?;
-
+
         Ok(ConstructedOperator::from_operator(Box::new(LookupJoin {
             connector,
             cache: Default::default(),
-            key_row_converter: RowConverter::new(exprs.iter().map(|e|
-                Ok(SortField::new(e.data_type(&input_schema.schema)?)))
-                .collect::<Result<_>>()?)?,
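+            // One SortField per join-key expression, so evaluated keys can be
+            // row-encoded and used as owned lookup keys in the cache.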
+            key_row_converter: RowConverter::new(
+                exprs
+                    .iter()
+                    .map(|e| Ok(SortField::new(e.data_type(&input_schema.schema)?)))
+                    .collect::<Result<_>>()?,
+            )?,
             key_exprs: exprs,
             result_row_converter,
             join_type: join_type.into(),
diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs
index da147c451..c7ed5ae93 100644
--- a/crates/arroyo-worker/src/engine.rs
+++ b/crates/arroyo-worker/src/engine.rs
@@ -782,11 +782,9 @@ pub async fn construct_node(
 ) -> OperatorNode {
     if chain.is_source() {
         let (head, _) = chain.iter().next().unwrap();
-        let ConstructedOperator::Source(operator) = construct_operator(
-            head.operator_name,
-            &head.operator_config,
-            registry,
-        ) else {
+        let ConstructedOperator::Source(operator) =
+            construct_operator(head.operator_name, &head.operator_config, registry)
+        else {
             unreachable!();
         };