Skip to content

Commit

Permalink
Improve Download Memory Usage (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
waahm7 authored Oct 8, 2024
1 parent 31550a0 commit 3dd20c8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 67 deletions.
27 changes: 20 additions & 7 deletions aws-s3-transfer-manager/src/operation/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use aws_smithy_types::byte_stream::ByteStream;
use body::Body;
use discovery::discover_obj;
use service::{distribute_work, ChunkResponse};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
Expand Down Expand Up @@ -96,9 +97,8 @@ fn handle_discovery_chunk(
completed: &mpsc::Sender<Result<ChunkResponse, crate::error::Error>>,
permit: OwnedWorkPermit,
) -> u64 {
let mut start_seq = 0;

if let Some(stream) = initial_chunk {
let seq = handle.ctx.next_seq();
let completed = completed.clone();
// spawn a task to actually read the discovery chunk without waiting for it so we
// can get started sooner on any remaining work (if any)
Expand All @@ -107,7 +107,7 @@ fn handle_discovery_chunk(
.collect()
.await
.map(|aggregated| ChunkResponse {
seq: start_seq,
seq,
data: Some(aggregated),
})
.map_err(error::discovery_failed);
Expand All @@ -122,25 +122,38 @@ fn handle_discovery_chunk(
);
}
});
start_seq = 1;
}
start_seq
handle.ctx.current_seq()
}

/// Download operation specific state
#[derive(Debug)]
pub(crate) struct DownloadState {}
pub(crate) struct DownloadState {
current_seq: AtomicU64,
}

type DownloadContext = TransferContext<DownloadState>;

impl DownloadContext {
fn new(handle: Arc<crate::client::Handle>) -> Self {
let state = Arc::new(DownloadState {});
let state = Arc::new(DownloadState {
current_seq: AtomicU64::new(0),
});
TransferContext { handle, state }
}

/// The target part size to use for this download
fn target_part_size_bytes(&self) -> u64 {
self.handle.download_part_size_bytes()
}

/// Returns the next seq to download
fn next_seq(&self) -> u64 {
self.state.current_seq.fetch_add(1, Ordering::SeqCst)
}

/// Returns the current seq
fn current_seq(&self) -> u64 {
self.state.current_seq.load(Ordering::SeqCst)
}
}
94 changes: 34 additions & 60 deletions aws-s3-transfer-manager/src/operation/download/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,39 @@ use super::{DownloadHandle, DownloadInput, DownloadInputBuilder};
#[derive(Debug, Clone)]
pub(super) struct DownloadChunkRequest {
pub(super) ctx: DownloadContext,
pub(super) request: ChunkRequest,
pub(super) remaining: RangeInclusive<u64>,
pub(super) input: DownloadInputBuilder,
pub(super) start_seq: u64,
}

fn next_chunk(
seq: u64,
remaining: RangeInclusive<u64>,
part_size: u64,
start_seq: u64,
input: DownloadInputBuilder,
) -> DownloadInputBuilder {
let start = remaining.start() + ((seq - start_seq) * part_size);
let end_inclusive = cmp::min(start + part_size - 1, *remaining.end());
input.range(header::Range::bytes_inclusive(start, end_inclusive))
}

/// handler (service fn) for a single chunk
async fn download_chunk_handler(
request: DownloadChunkRequest,
) -> Result<ChunkResponse, error::Error> {
let ctx = request.ctx;
let request = request.request;

let op = request.input.into_sdk_operation(ctx.client());

let seq = ctx.next_seq();
let part_size = ctx.handle.download_part_size_bytes();
let input = next_chunk(
seq,
request.remaining,
part_size,
request.start_seq,
request.input,
);

let op = input.into_sdk_operation(ctx.client());
let mut resp = op
.send()
.await
Expand All @@ -43,12 +64,12 @@ async fn download_chunk_handler(

let bytes = body
.collect()
.instrument(tracing::debug_span!("collect-body", seq = request.seq))
.instrument(tracing::debug_span!("collect-body", seq = seq))
.await
.map_err(error::from_kind(error::ErrorKind::ChunkFailed))?;

Ok(ChunkResponse {
seq: request.seq,
seq,
data: Some(bytes),
})
}
Expand All @@ -68,23 +89,6 @@ pub(super) fn chunk_service(
.service(svc)
}

// FIXME - should probably be enum ChunkRequest { Range(..), Part(..) } or have an inner field like such
#[derive(Debug, Clone)]
pub(super) struct ChunkRequest {
// byte range to download
pub(super) range: RangeInclusive<u64>,
pub(super) input: DownloadInputBuilder,
// sequence number
pub(super) seq: u64,
}

impl ChunkRequest {
/// Size of this chunk request in bytes
pub(super) fn size(&self) -> u64 {
self.range.end() - self.range.start() + 1
}
}

#[derive(Debug, Clone)]
pub(crate) struct ChunkResponse {
// TODO(aws-sdk-rust#1159, design) - consider PartialOrd for ChunkResponse and hiding `seq` as internal only detail
Expand All @@ -110,31 +114,18 @@ pub(super) fn distribute_work(
start_seq: u64,
comp_tx: mpsc::Sender<Result<ChunkResponse, error::Error>>,
) {
let end = *remaining.end();
let mut pos = *remaining.start();
let mut remaining = end - pos + 1;
let mut seq = start_seq;

let svc = chunk_service(&handle.ctx);

let part_size = handle.ctx.target_part_size_bytes();
let input: DownloadInputBuilder = input.into();

while remaining > 0 {
let start = pos;
let end_inclusive = cmp::min(pos + part_size - 1, end);

let chunk_req = next_chunk(start, end_inclusive, seq, input.clone());
tracing::trace!(
"distributing chunk(size={}): {:?}",
chunk_req.size(),
chunk_req
);
let chunk_size = chunk_req.size();

let size = *remaining.end() - *remaining.start() + 1;
let num_parts = size.div_ceil(part_size);
for seq in 0..num_parts {
let req = DownloadChunkRequest {
ctx: handle.ctx.clone(),
request: chunk_req,
remaining: remaining.clone(),
input: input.clone(),
start_seq,
};

let svc = svc.clone();
Expand All @@ -147,25 +138,8 @@ pub(super) fn distribute_work(
}
}
.instrument(tracing::debug_span!("download-chunk", seq = seq));

handle.tasks.spawn(task);

seq += 1;
remaining -= chunk_size;
tracing::trace!("remaining = {}", remaining);
pos += chunk_size;
}

tracing::trace!("work fully distributed");
}

fn next_chunk(
start: u64,
end_inclusive: u64,
seq: u64,
input: DownloadInputBuilder,
) -> ChunkRequest {
let range = start..=end_inclusive;
let input = input.range(header::Range::bytes_inclusive(start, end_inclusive));
ChunkRequest { seq, range, input }
}

0 comments on commit 3dd20c8

Please sign in to comment.