Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/dbsp/src/circuit/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ impl Consensus {
notify_receiver,
exchange,
} => {
while !exchange.try_send_all(Runtime::worker_index(), &mut repeat(local)) {
while !exchange.try_send_all(Runtime::worker_index(), repeat(local)) {
if Runtime::kill_in_progress() {
return Err(SchedulerError::Killed);
}
Expand Down Expand Up @@ -1137,7 +1137,7 @@ where
notify_receiver,
exchange,
} => {
while !exchange.try_send_all(Runtime::worker_index(), &mut repeat(local.clone())) {
while !exchange.try_send_all(Runtime::worker_index(), repeat(local.clone())) {
if Runtime::kill_in_progress() {
return Err(SchedulerError::Killed);
}
Expand Down
112 changes: 72 additions & 40 deletions crates/dbsp/src/operator/communication/exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::{
};
use crossbeam_utils::CachePadded;
use futures::{future, prelude::*, stream::FuturesUnordered};
use itertools::Itertools;
use std::{
borrow::Cow,
collections::HashMap,
Expand Down Expand Up @@ -238,10 +239,6 @@ struct InnerExchange {
/// A callback that takes the raw data exchanged over RPC and deserializes
/// and delivers it to the receiver's mailbox.
deliver: Box<dyn Fn(Vec<u8>, usize, usize) + Send + Sync + 'static>,
/// The amount of time spent in `deliver`.
delivery_usecs: AtomicU64,
/// The number of bytes passed to `deliver`.
delivered_bytes: AtomicUsize,
}

impl InnerExchange {
Expand Down Expand Up @@ -270,8 +267,6 @@ impl InnerExchange {
.collect(),
sender_callbacks: (0..npeers).map(|_| Callback::empty()).collect(),
deliver: Box::new(deliver),
delivery_usecs: AtomicU64::new(0),
delivered_bytes: AtomicUsize::new(0),
sent: AtomicUsize::new(0),
}
}
Expand Down Expand Up @@ -304,19 +299,12 @@ impl InnerExchange {
let receivers = &self.local_workers;

// Deliver all of the data into the exchange's mailboxes.
let start = Instant::now();
let mut delivered_bytes = 0;
for (sender, data) in senders.clone().zip(data.into_iter()) {
assert_eq!(data.len(), receivers.len());
for (receiver, data) in receivers.clone().zip(data.into_iter()) {
delivered_bytes += data.len();
(self.deliver)(data, sender, receiver);
}
}
self.delivery_usecs
.fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
self.delivered_bytes
.fetch_add(delivered_bytes, Ordering::Relaxed);

// Increment the receiver counters and deliver callbacks if necessary.
for receiver in receivers.clone() {
Expand Down Expand Up @@ -370,6 +358,20 @@ impl InnerExchange {
}
}

enum Mailbox<T> {
Serialized(Vec<u8>),
Plain(T),
}

impl<T> Mailbox<T> {
fn into_serialized(self) -> Option<Vec<u8>> {
match self {
Mailbox::Serialized(bytes) => Some(bytes),
Mailbox::Plain(_) => None,
}
}
}

/// `Exchange` is an N-to-N communication primitive that partitions data across
/// multiple concurrent threads.
///
Expand Down Expand Up @@ -417,14 +419,21 @@ pub(crate) struct Exchange<T> {
/// | | |RRRRR| | |
/// v |-----|-----|-----|-----|
/// ```
mailboxes: Arc<Vec<Mutex<Option<T>>>>,
mailboxes: Arc<Vec<Mutex<Option<Mailbox<T>>>>>,
serialize: Box<dyn Fn(T) -> Vec<u8> + Send + Sync>,
deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,

/// The amount of time we've spent calling `serialize`.
serialization_usecs: AtomicU64,

/// The number of bytes produced by `serialize`.
serialized_bytes: AtomicUsize,

/// The amount of time spent calling `deserialize`.
deserialization_usecs: AtomicU64,

/// The number of bytes passed to `deserialize`.
deserialized_bytes: AtomicUsize,
}

async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
Expand Down Expand Up @@ -474,15 +483,14 @@ where
deserialize: Box<dyn Fn(Vec<u8>) -> T + Send + Sync>,
) -> Self {
let npeers = Runtime::num_workers();
let mailboxes: Arc<Vec<Mutex<Option<T>>>> =
let mailboxes: Arc<Vec<Mutex<Option<Mailbox<T>>>>> =
Arc::new((0..npeers * npeers).map(|_| Mutex::new(None)).collect());
let mailboxes2: Arc<Vec<Mutex<Option<T>>>> = mailboxes.clone();
let mailboxes2 = mailboxes.clone();
let deliver = move |data: Vec<u8>, sender, receiver| {
let index: usize = sender * npeers + receiver;
let data = deserialize(data);
let mut mailbox = mailboxes2[index].lock().unwrap();
assert!((*mailbox).is_none());
*mailbox = Some(data);
*mailbox = Some(Mailbox::Serialized(data));
};

let inner = Arc::new(InnerExchange::new(exchange_id, deliver, clients));
Expand All @@ -496,8 +504,11 @@ where
inner,
mailboxes,
serialize,
deserialize,
serialization_usecs: AtomicU64::new(0),
serialized_bytes: AtomicUsize::new(0),
deserialization_usecs: AtomicU64::new(0),
deserialized_bytes: AtomicUsize::new(0),
}
}

Expand All @@ -507,7 +518,7 @@ where
}

/// Returns a reference to a mailbox for the sender/receiver pair.
fn mailbox(&self, sender: usize, receiver: usize) -> &Mutex<Option<T>> {
fn mailbox(&self, sender: usize, receiver: usize) -> &Mutex<Option<Mailbox<T>>> {
&self.mailboxes[self.inner.mailbox_index(sender, receiver)]
}

Expand Down Expand Up @@ -573,10 +584,11 @@ where
/// # Panics
///
/// Panics if `data` yields fewer than `self.npeers` items.
pub(crate) fn try_send_all<I>(self: &Arc<Self>, sender: usize, data: &mut I) -> bool
where
I: Iterator<Item = T> + Send,
{
pub(crate) fn try_send_all(
self: &Arc<Self>,
sender: usize,
data: impl Iterator<Item = T>,
) -> bool {
let npeers = self.inner.npeers;
if self.inner.sender_counters[sender]
.compare_exchange(npeers, 0, Ordering::AcqRel, Ordering::Acquire)
Expand All @@ -587,10 +599,23 @@ where

// Deliver all of the data to local mailboxes.
let local_workers = &self.inner.local_workers;
for receiver in 0..npeers {
*self.mailbox(sender, receiver).lock().unwrap() = data.next();
for (receiver, item) in (0..npeers).zip_eq(data.take(npeers)) {
let is_local = local_workers.contains(&receiver);
let mailbox = if is_local {
Mailbox::Plain(item)
} else {
let start = Instant::now();
let serialized = (self.serialize)(item);
self.serialization_usecs
.fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
self.serialized_bytes
.fetch_add(serialized.len(), Ordering::Relaxed);

Mailbox::Serialized(serialized)
};
*self.mailbox(sender, receiver).lock().unwrap() = Some(mailbox);

if local_workers.contains(&receiver) {
if is_local {
let old_counter =
self.inner.receiver_counters[receiver].fetch_add(1, Ordering::AcqRel);
if old_counter >= npeers - 1 {
Expand Down Expand Up @@ -619,7 +644,6 @@ where
// accumulate all of the data from our local `senders` to all
// of the `receivers` on that host.
let senders = &this.inner.local_workers;
let start = Instant::now();
for host in runtime.layout().other_hosts() {
let receivers = &host.workers;
let mut serialized_bytes = 0;
Expand All @@ -629,23 +653,20 @@ where
receivers
.clone()
.map(|receiver| {
let item = this
let serialized = this
.mailbox(sender, receiver)
.lock()
.unwrap()
.take()
.unwrap();
let serialized = (this.serialize)(item);
.unwrap()
.into_serialized()
.expect("remote mailboxes should always be serialized");
serialized_bytes += serialized.len();
serialized
})
.collect()
})
.collect();
this.serialization_usecs
.fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
this.serialized_bytes
.fetch_add(serialized_bytes, Ordering::Relaxed);

let client = this.inner.clients.connect(receivers.start).await;

Expand Down Expand Up @@ -706,12 +727,24 @@ where
}

for sender in 0..self.inner.npeers {
let data = self
let mailbox = self
.mailbox(sender, receiver)
.lock()
.unwrap()
.take()
.unwrap();
let data = match mailbox {
Mailbox::Plain(item) => item,
Mailbox::Serialized(bytes) => {
self.deserialized_bytes
.fetch_add(bytes.len(), Ordering::Relaxed);
let start = Instant::now();
let item = (self.deserialize)(bytes);
self.deserialization_usecs
.fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
item
}
};
cb(data);
if self.inner.local_workers.contains(&sender) {
let old_counter = self.inner.sender_counters[sender].fetch_add(1, Ordering::AcqRel);
Expand Down Expand Up @@ -1015,7 +1048,7 @@ where

let res = self.exchange.try_send_all(
self.worker_index,
&mut self.outputs.drain(..).map(|x| (x, self.flushed)),
self.outputs.drain(..).map(|x| (x, self.flushed)),
);
self.flushed = false;
debug_assert!(res);
Expand Down Expand Up @@ -1106,8 +1139,8 @@ where
meta.extend(metadata! {
OUTPUT_BATCHES_STATS => self.output_batch_stats.metadata(),
EXCHANGE_WAIT_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.total_wait_time.load(Ordering::Acquire))),
EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.inner.delivery_usecs.load(Ordering::Acquire))),
EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.inner.delivered_bytes.load(Ordering::Acquire)),
EXCHANGE_DESERIALIZATION_TIME_SECONDS => MetaItem::Duration(Duration::from_micros(self.exchange.deserialization_usecs.load(Ordering::Acquire))),
EXCHANGE_DESERIALIZED_BYTES => MetaItem::bytes(self.exchange.deserialized_bytes.load(Ordering::Acquire)),
});
}

Expand Down Expand Up @@ -1352,9 +1385,8 @@ mod tests {

for round in 0..ROUNDS {
let output_data = vec![round; WORKERS];
let mut output_iter = output_data.clone().into_iter();
loop {
if exchange.try_send_all(Runtime::worker_index(), &mut output_iter) {
if exchange.try_send_all(Runtime::worker_index(), output_data.iter().copied()) {
break;
}

Expand Down
Loading