Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement access control with VPC endpoint checks and block for public internet / VPC #10143

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions proxy/src/auth/backend/console_redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,15 @@ async fn authenticate(
}
}

// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
if let Some(public_access_allowed) = db_info.public_access_allowed {
awarus marked this conversation as resolved.
Show resolved Hide resolved
if !public_access_allowed {
return Err(auth::AuthError::NetworkNotAllowed);
}
}

client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;

// This config should be self-contained, because we won't
Expand Down
124 changes: 104 additions & 20 deletions proxy/src/auth/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::{
self, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
self, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds,
CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::proxy::NeonOptions;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
Expand Down Expand Up @@ -270,8 +272,8 @@ async fn auth_quirks(
Ok(info) => (info, None),
};

debug!("fetching user's authentication info");
let (allowed_ips, maybe_secret) = api.get_allowed_ips_and_secret(ctx, &info).await?;
debug!("fetching authentication info and allowlists");
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;

// check allowed list
if config.ip_allowlist_check_enabled
Expand All @@ -280,13 +282,47 @@ async fn auth_quirks(
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}

// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
if config.is_vpc_acccess_proxy {
if access_blocks.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let extra = ctx.extra();
let incoming_vpc_endpoint_id = match extra {
None => "".to_string(),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
match String::from_utf8(vpce_id.to_vec()) {
Ok(s) => s,
Err(_e) => "".to_string(),
}
}
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
if incoming_vpc_endpoint_id == "" {
stradig marked this conversation as resolved.
Show resolved Hide resolved
// This should never happen, would be a setup error on our side.
return Err(AuthError::MissingEndpointName);
}
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else {
if access_blocks.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
}

if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => api.get_role_secret(ctx, &info).await?,
};
let cached_secret = api.get_role_secret(ctx, &info).await?;
let (cached_entry, secret) = cached_secret.take_value();

let secret = if let Some(secret) = secret {
Expand Down Expand Up @@ -428,15 +464,37 @@ impl Backend<'_, ComputeUserInfo> {
}
}

pub(crate) async fn get_allowed_ips_and_secret(
pub(crate) async fn get_allowed_ips(
&self,
ctx: &RequestContext,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
) -> Result<CachedAllowedIps, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}

pub(crate) async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_ips_and_secret(ctx, user_info).await
api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
}
Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}

pub(crate) async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_block_public_or_vpc_access(ctx, user_info).await
}
Self::Local(_) => Ok(Cached::new_uncached(Default::default())),
}
}
}
Expand Down Expand Up @@ -482,7 +540,10 @@ mod tests {
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret};
use crate::control_plane::{
self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
Expand All @@ -491,6 +552,8 @@ mod tests {

struct Auth {
ips: Vec<IpPattern>,
vpc_endpoint_ids: Vec<String>,
access_blocker_flags: AccessBlockerFlags,
secret: AuthSecret,
}

Expand All @@ -503,17 +566,31 @@ mod tests {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}

async fn get_allowed_ips_and_secret(
async fn get_allowed_ips(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
}

async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
self.vpc_endpoint_ids.clone(),
)))
}

async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<
(CachedAllowedIps, Option<CachedRoleSecret>),
control_plane::errors::GetAuthInfoError,
> {
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
Ok(CachedAccessBlockerFlags::new_uncached(
self.access_blocker_flags.clone(),
))
}

Expand Down Expand Up @@ -543,6 +620,7 @@ mod tests {
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,
accept_jwts: false,
console_redirect_confirmation_timeout: std::time::Duration::from_secs(5),
Expand Down Expand Up @@ -610,6 +688,8 @@ mod tests {
let ctx = RequestContext::test();
let api = Auth {
ips: vec![],
vpc_endpoint_ids: vec![],
access_blocker_flags: AccessBlockerFlags::default(),
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};

Expand Down Expand Up @@ -687,6 +767,8 @@ mod tests {
let ctx = RequestContext::test();
let api = Auth {
ips: vec![],
vpc_endpoint_ids: vec![],
access_blocker_flags: AccessBlockerFlags::default(),
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};

Expand Down Expand Up @@ -739,6 +821,8 @@ mod tests {
let ctx = RequestContext::test();
let api = Auth {
ips: vec![],
vpc_endpoint_ids: vec![],
access_blocker_flags: AccessBlockerFlags::default(),
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};

Expand Down
25 changes: 25 additions & 0 deletions proxy/src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ pub(crate) enum AuthError {
)]
MissingEndpointName,

#[error(
"VPC endpoint ID is not specified. \
This endpoint requires a VPC endpoint ID to connect."
)]
MissingVPCEndpointId,

#[error("password authentication failed for user '{0}'")]
PasswordFailed(Box<str>),

Expand All @@ -69,6 +75,15 @@ pub(crate) enum AuthError {
)]
IpAddressNotAllowed(IpAddr),

#[error("This connection is trying to access this endpoint from a blocked network.")]
NetworkNotAllowed,

#[error(
"This VPC endpoint id {0} is not allowed to connect to this endpoint. \
Please add it to the allowed list in the Neon console."
)]
VpcEndpointIdNotAllowed(String),

#[error("Too many connections to this endpoint. Please try again later.")]
TooManyConnections,

Expand All @@ -95,6 +110,10 @@ impl AuthError {
AuthError::IpAddressNotAllowed(ip)
}

pub(crate) fn vpc_endpoint_id_not_allowed(id: String) -> Self {
AuthError::VpcEndpointIdNotAllowed(id)
}

pub(crate) fn too_many_connections() -> Self {
AuthError::TooManyConnections
}
Expand Down Expand Up @@ -122,8 +141,11 @@ impl UserFacingError for AuthError {
Self::BadAuthMethod(_) => self.to_string(),
Self::MalformedPassword(_) => self.to_string(),
Self::MissingEndpointName => self.to_string(),
Self::MissingVPCEndpointId => self.to_string(),
Self::Io(_) => "Internal error".to_string(),
Self::IpAddressNotAllowed(_) => self.to_string(),
Self::NetworkNotAllowed => self.to_string(),
Self::VpcEndpointIdNotAllowed(_) => self.to_string(),
Self::TooManyConnections => self.to_string(),
Self::UserTimeout(_) => self.to_string(),
Self::ConfirmationTimeout(_) => self.to_string(),
Expand All @@ -142,8 +164,11 @@ impl ReportableError for AuthError {
Self::BadAuthMethod(_) => crate::error::ErrorKind::User,
Self::MalformedPassword(_) => crate::error::ErrorKind::User,
Self::MissingEndpointName => crate::error::ErrorKind::User,
Self::MissingVPCEndpointId => crate::error::ErrorKind::User,
Self::Io(_) => crate::error::ErrorKind::ClientDisconnect,
Self::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
Self::NetworkNotAllowed => crate::error::ErrorKind::User,
Self::VpcEndpointIdNotAllowed(_) => crate::error::ErrorKind::User,
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
Self::UserTimeout(_) => crate::error::ErrorKind::User,
Self::ConfirmationTimeout(_) => crate::error::ErrorKind::User,
Expand Down
1 change: 1 addition & 0 deletions proxy/src/bin/local_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,
accept_jwts: true,
console_redirect_confirmation_timeout: Duration::ZERO,
Expand Down
1 change: 1 addition & 0 deletions proxy/src/bin/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
console_redirect_confirmation_timeout: args.webauth_confirmation_timeout,
Expand Down
Loading
Loading