Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove function config allocations per invocation. #732

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[workspace]
resolver = "2"
members = [
"lambda-http",
"lambda-integration-tests",
Expand Down
2 changes: 1 addition & 1 deletion lambda-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ hyper = { version = "0.14.20", features = [
"server",
] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde = { version = "1", features = ["derive", "rc"] }
serde_json = "^1"
bytes = "1.0"
http = "0.2"
Expand Down
28 changes: 18 additions & 10 deletions lambda-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::{
future::Future,
marker::PhantomData,
panic,
sync::Arc,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt};
Expand Down Expand Up @@ -58,6 +59,8 @@ pub struct Config {
pub log_group: String,
}

type RefConfig = Arc<Config>;

impl Config {
/// Attempts to read configuration from environment variables.
pub fn from_env() -> Result<Self, Error> {
Expand Down Expand Up @@ -86,7 +89,7 @@ where

struct Runtime<C: Service<http::Uri> = HttpConnector> {
client: Client<C>,
config: Config,
config: RefConfig,
}

impl<C> Runtime<C>
Expand Down Expand Up @@ -127,8 +130,7 @@ where
continue;
}

let ctx: Context = Context::try_from(parts.headers)?;
let ctx: Context = ctx.with_config(&self.config);
let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?;
let request_id = &ctx.request_id.clone();

let request_span = match &ctx.xray_trace_id {
Expand Down Expand Up @@ -263,7 +265,10 @@ where
trace!("Loading config from env");
let config = Config::from_env()?;
let client = Client::builder().build().expect("Unable to create a runtime client");
let runtime = Runtime { client, config };
let runtime = Runtime {
client,
config: Arc::new(config),
};

let client = &runtime.client;
let incoming = incoming(client);
Expand Down Expand Up @@ -294,15 +299,15 @@ mod endpoint_tests {
},
simulated,
types::Diagnostic,
Error, Runtime,
Config, Error, Runtime,
};
use futures::future::BoxFuture;
use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri};
use hyper::{server::conn::Http, service::service_fn, Body};
use lambda_runtime_api_client::Client;
use serde_json::json;
use simulated::DuplexStreamWrapper;
use std::{convert::TryFrom, env, marker::PhantomData};
use std::{convert::TryFrom, env, marker::PhantomData, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
select,
Expand Down Expand Up @@ -531,9 +536,12 @@ mod endpoint_tests {
if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() {
env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log");
}
let config = crate::Config::from_env().expect("Failed to read env vars");
let config = Config::from_env().expect("Failed to read env vars");

let runtime = Runtime { client, config };
let runtime = Runtime {
client,
config: Arc::new(config),
};
let client = &runtime.client;
let incoming = incoming(client).take(1);
runtime.run(incoming, f).await?;
Expand Down Expand Up @@ -568,13 +576,13 @@ mod endpoint_tests {

let f = crate::service_fn(func);

let config = crate::Config {
let config = Arc::new(Config {
function_name: "test_fn".to_string(),
memory: 128,
version: "1".to_string(),
log_stream: "test_stream".to_string(),
log_group: "test_log".to_string(),
};
});

let runtime = Runtime { client, config };
let client = &runtime.client;
Expand Down
81 changes: 49 additions & 32 deletions lambda-runtime/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Config, Error};
use crate::{Error, RefConfig};
use base64::prelude::*;
use bytes::Bytes;
use http::{HeaderMap, HeaderValue, StatusCode};
Expand Down Expand Up @@ -97,7 +97,7 @@ pub struct CognitoIdentity {
/// are populated using the [Lambda environment variables](https://docs.aws.amazon.com/lambda/latest/dg/current-supported-versions.html)
/// and [the headers returned by the poll request to the Runtime APIs](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html#runtimes-api-next).
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq, Default, Serialize, Deserialize)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Context {
/// The AWS request ID generated by the Lambda service.
pub request_id: String,
Expand All @@ -117,12 +117,14 @@ pub struct Context {
/// Lambda function configuration from the local environment variables.
/// Includes information such as the function name, memory allocation,
/// version, and log streams.
pub env_config: Config,
pub env_config: RefConfig,
}

impl TryFrom<HeaderMap> for Context {
impl TryFrom<(RefConfig, HeaderMap)> for Context {
type Error = Error;
fn try_from(headers: HeaderMap) -> Result<Self, Self::Error> {
fn try_from(data: (RefConfig, HeaderMap)) -> Result<Self, Self::Error> {
let env_config = data.0;
let headers = data.1;
let client_context: Option<ClientContext> = if let Some(value) = headers.get("lambda-runtime-client-context") {
serde_json::from_str(value.to_str()?)?
} else {
Expand Down Expand Up @@ -158,13 +160,20 @@ impl TryFrom<HeaderMap> for Context {
.map(|v| String::from_utf8_lossy(v.as_bytes()).to_string()),
client_context,
identity,
..Default::default()
env_config,
};

Ok(ctx)
}
}

impl Context {
/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
}
}

/// Incoming Lambda request containing the event payload and context.
#[derive(Clone, Debug)]
pub struct LambdaEvent<T> {
Expand Down Expand Up @@ -273,6 +282,8 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::Config;
use std::sync::Arc;

#[test]
fn round_trip_lambda_error() {
Expand All @@ -292,6 +303,8 @@ mod test {

#[test]
fn context_with_expected_values_and_types_resolves() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
Expand All @@ -300,16 +313,18 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
}

#[test]
fn context_with_certain_missing_headers_still_resolves() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
}

Expand Down Expand Up @@ -338,7 +353,9 @@ mod test {
"lambda-runtime-client-context",
HeaderValue::from_str(&client_context_str).unwrap(),
);
let tried = Context::try_from(headers);

let config = Arc::new(Config::default());
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.client_context.is_some());
Expand All @@ -347,17 +364,20 @@ mod test {

#[test]
fn context_with_empty_client_context_resolves() {
let config = Arc::new(Config::default());
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
assert!(tried.unwrap().client_context.is_some());
}

#[test]
fn context_with_identity_resolves() {
let config = Arc::new(Config::default());

let cognito_identity = CognitoIdentity {
identity_id: String::new(),
identity_pool_id: String::new(),
Expand All @@ -370,7 +390,7 @@ mod test {
"lambda-runtime-cognito-identity",
HeaderValue::from_str(&cognito_identity_str).unwrap(),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_ok());
let tried = tried.unwrap();
assert!(tried.identity.is_some());
Expand All @@ -379,6 +399,8 @@ mod test {

#[test]
fn context_with_bad_deadline_type_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
Expand All @@ -390,86 +412,81 @@ mod test {
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_bad_client_context_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-client-context",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_empty_identity_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}"));
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
fn context_with_bad_identity_is_err() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-cognito-identity",
HeaderValue::from_static("BAD-Type,not JSON"),
);
let tried = Context::try_from(headers);
let tried = Context::try_from((config, headers));
assert!(tried.is_err());
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_request_id_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
Context::try_from((config, headers));
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_deadline_should_panic() {
let config = Arc::new(Config::default());

let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
}
}

impl Context {
/// Add environment details to the context by setting `env_config`.
pub fn with_config(self, config: &Config) -> Self {
Self {
env_config: config.clone(),
..self
}
}

/// The execution deadline for the current invocation.
pub fn deadline(&self) -> SystemTime {
SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline)
Context::try_from((config, headers));
}
}