Skip to content

Commit

Permalink
Merge pull request #666 from trtt/fix/streamconsumer-wake
Browse files Browse the repository at this point in the history
Fix StreamConsumer wakeup races
  • Loading branch information
davidblewett authored Sep 24, 2024
2 parents ddba5fa + 3f2b74b commit c6d9a65
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 66 deletions.
38 changes: 32 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,21 @@ impl NativeClient {
}
}

pub(crate) enum EventPollResult<T> {
None,
EventConsumed,
Event(T),
}

impl<T> From<EventPollResult<T>> for Option<T> {
fn from(val: EventPollResult<T>) -> Self {
match val {
EventPollResult::None | EventPollResult::EventConsumed => None,
EventPollResult::Event(evt) => Some(evt),
}
}
}

/// A low-level rdkafka client.
///
/// This type is the basis of the consumers and producers in the [`consumer`]
Expand Down Expand Up @@ -278,31 +293,42 @@ impl<C: ClientContext> Client<C> {
&self.context
}

pub(crate) fn poll_event(&self, queue: &NativeQueue, timeout: Timeout) -> Option<NativeEvent> {
pub(crate) fn poll_event(
&self,
queue: &NativeQueue,
timeout: Timeout,
) -> EventPollResult<NativeEvent> {
let event = unsafe { NativeEvent::from_ptr(queue.poll(timeout)) };
if let Some(ev) = event {
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_LOG => self.handle_log_event(ev.ptr()),
rdsys::RD_KAFKA_EVENT_STATS => self.handle_stats_event(ev.ptr()),
rdsys::RD_KAFKA_EVENT_LOG => {
self.handle_log_event(ev.ptr());
return EventPollResult::EventConsumed;
}
rdsys::RD_KAFKA_EVENT_STATS => {
self.handle_stats_event(ev.ptr());
return EventPollResult::EventConsumed;
}
rdsys::RD_KAFKA_EVENT_ERROR => {
// rdkafka reports consumer errors via RD_KAFKA_EVENT_ERROR but producer errors gets
// embedded on the ack returned via RD_KAFKA_EVENT_DR. Hence we need to return this event
// for the consumer case in order to return the error to the user.
self.handle_error_event(ev.ptr());
return Some(ev);
return EventPollResult::Event(ev);
}
rdsys::RD_KAFKA_EVENT_OAUTHBEARER_TOKEN_REFRESH => {
if C::ENABLE_REFRESH_OAUTH_TOKEN {
self.handle_oauth_refresh_event(ev.ptr());
}
return EventPollResult::EventConsumed;
}
_ => {
return Some(ev);
return EventPollResult::Event(ev);
}
}
}
None
EventPollResult::None
}

fn handle_log_event(&self, event: *mut RDKafkaEvent) {
Expand Down
84 changes: 47 additions & 37 deletions src/consumer/base_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use log::{error, warn};
use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeClient, NativeQueue};
use crate::client::{Client, EventPollResult, NativeClient, NativeQueue};
use crate::config::{
ClientConfig, FromClientConfig, FromClientConfigAndContext, NativeClientConfig,
};
Expand Down Expand Up @@ -117,59 +117,69 @@ where
///
/// The returned message lives in the memory of the consumer and cannot outlive it.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
self.poll_queue(self.get_queue(), timeout)
self.poll_queue(self.get_queue(), timeout).into()
}

pub(crate) fn poll_queue<T: Into<Timeout>>(
&self,
queue: &NativeQueue,
timeout: T,
) -> Option<KafkaResult<BorrowedMessage<'_>>> {
) -> EventPollResult<KafkaResult<BorrowedMessage<'_>>> {
let now = Instant::now();
let mut timeout = timeout.into();
let initial_timeout = timeout.into();
let mut timeout = initial_timeout;
let min_poll_interval = self.context().main_queue_min_poll_interval();
loop {
let op_timeout = std::cmp::min(timeout, min_poll_interval);
let maybe_event = self.client().poll_event(queue, op_timeout);
if let Some(event) = maybe_event {
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_FETCH => {
if let Some(result) = self.handle_fetch_event(event) {
return Some(result);
match maybe_event {
EventPollResult::Event(event) => {
let evtype = unsafe { rdsys::rd_kafka_event_type(event.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_FETCH => {
if let Some(result) = self.handle_fetch_event(event) {
return EventPollResult::Event(result);
}
}
}
rdsys::RD_KAFKA_EVENT_ERROR => {
if let Some(err) = self.handle_error_event(event) {
return Some(Err(err));
rdsys::RD_KAFKA_EVENT_ERROR => {
if let Some(err) = self.handle_error_event(event) {
return EventPollResult::Event(Err(err));
}
}
}
rdsys::RD_KAFKA_EVENT_REBALANCE => {
self.handle_rebalance_event(event);
if timeout != Timeout::Never {
return None;
rdsys::RD_KAFKA_EVENT_REBALANCE => {
self.handle_rebalance_event(event);
if timeout != Timeout::Never {
return EventPollResult::EventConsumed;
}
}
}
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
self.handle_offset_commit_event(event);
if timeout != Timeout::Never {
return None;
rdsys::RD_KAFKA_EVENT_OFFSET_COMMIT => {
self.handle_offset_commit_event(event);
if timeout != Timeout::Never {
return EventPollResult::EventConsumed;
}
}
_ => {
let evname = unsafe {
let evname = rdsys::rd_kafka_event_name(event.ptr());
CStr::from_ptr(evname).to_string_lossy()
};
warn!("Ignored event '{evname}' on consumer poll");
}
}
_ => {
let evname = unsafe {
let evname = rdsys::rd_kafka_event_name(event.ptr());
CStr::from_ptr(evname).to_string_lossy()
};
warn!("Ignored event '{evname}' on consumer poll");
}
EventPollResult::None => {
timeout = initial_timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return EventPollResult::None;
}
}
}

timeout = timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return None;
}
EventPollResult::EventConsumed => {
timeout = initial_timeout.saturating_sub(now.elapsed());
if timeout.is_zero() {
return EventPollResult::EventConsumed;
}
}
};
}
}

Expand Down Expand Up @@ -836,7 +846,7 @@ where
/// associated consumer regularly, even if no messages are expected, to
/// serve events.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) -> Option<KafkaResult<BorrowedMessage<'_>>> {
self.consumer.poll_queue(&self.queue, timeout)
self.consumer.poll_queue(&self.queue, timeout).into()
}

/// Sets a callback that will be invoked whenever the queue becomes
Expand Down
56 changes: 35 additions & 21 deletions src/consumer/stream_consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use slab::Slab;
use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeQueue};
use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
use crate::consumer::{
Expand Down Expand Up @@ -122,11 +122,12 @@ impl<'a, C: ConsumerContext> MessageStream<'a, C> {
}
}

fn poll(&self) -> Option<KafkaResult<BorrowedMessage<'a>>> {
fn poll(&self) -> EventPollResult<KafkaResult<BorrowedMessage<'a>>> {
if let Some(queue) = self.partition_queue {
self.consumer.poll_queue(queue, Duration::ZERO)
} else {
self.consumer.poll(Duration::ZERO)
self.consumer
.poll_queue(self.consumer.get_queue(), Duration::ZERO)
}
}
}
Expand All @@ -135,25 +136,38 @@ impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
type Item = KafkaResult<BorrowedMessage<'a>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// If there is a message ready, yield it immediately to avoid the
// taking the lock in `self.set_waker`.
if let Some(message) = self.poll() {
return Poll::Ready(Some(message));
}

// Otherwise, we need to wait for a message to become available. Store
// the waker so that we are woken up if the queue flips from non-empty
// to empty. We have to store the waker repatedly in case this future
// migrates between tasks.
self.wakers.set_waker(self.slot, cx.waker().clone());

// Check whether a new message became available after we installed the
// waker. This avoids a race where `poll` returns None to indicate that
// the queue is empty, but the queue becomes non-empty before we've
// installed the waker.
match self.poll() {
None => Poll::Pending,
Some(message) => Poll::Ready(Some(message)),
EventPollResult::Event(message) => {
// If there is a message ready, yield it immediately to avoid the
// taking the lock in `self.set_waker`.
Poll::Ready(Some(message))
}
EventPollResult::EventConsumed => {
// Event was consumed, yield to runtime
cx.waker().wake_by_ref();
Poll::Pending
}
EventPollResult::None => {
// Otherwise, we need to wait for a message to become available. Store
// the waker so that we are woken up if the queue flips from non-empty
// to empty. We have to store the waker repatedly in case this future
// migrates between tasks.
self.wakers.set_waker(self.slot, cx.waker().clone());

// Check whether a new message became available after we installed the
// waker. This avoids a race where `poll` returns None to indicate that
// the queue is empty, but the queue becomes non-empty before we've
// installed the waker.
match self.poll() {
EventPollResult::Event(message) => Poll::Ready(Some(message)),
EventPollResult::EventConsumed => {
// Event was consumed, yield to runtime
cx.waker().wake_by_ref();
Poll::Pending
}
EventPollResult::None => Poll::Pending,
}
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/producer/base_producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ use rdkafka_sys as rdsys;
use rdkafka_sys::rd_kafka_vtype_t::*;
use rdkafka_sys::types::*;

use crate::client::{Client, NativeQueue};
use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::ConsumerGroupMetadata;
use crate::error::{IsError, KafkaError, KafkaResult, RDKafkaError};
Expand Down Expand Up @@ -363,7 +363,7 @@ where
/// the message delivery callbacks.
pub fn poll<T: Into<Timeout>>(&self, timeout: T) {
let event = self.client().poll_event(&self.queue, timeout.into());
if let Some(ev) = event {
if let EventPollResult::Event(ev) = event {
let evtype = unsafe { rdsys::rd_kafka_event_type(ev.ptr()) };
match evtype {
rdsys::RD_KAFKA_EVENT_DR => self.handle_delivery_report_event(ev),
Expand Down

0 comments on commit c6d9a65

Please sign in to comment.