Skip to content

Commit

Permalink
Add ws_proxy option for connecting gateway through a proxy.
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzzRuby committed Jan 12, 2024
1 parent d484498 commit b596a76
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 8 deletions.
21 changes: 21 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub struct ClientBuilder {
event_handlers: Vec<Arc<dyn EventHandler>>,

Check failure on line 78 in src/client/mod.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/client/mod.rs
raw_event_handlers: Vec<Arc<dyn RawEventHandler>>,
presence: PresenceData,
ws_proxy: Option<String>
}

#[cfg(feature = "gateway")]
Expand Down Expand Up @@ -112,6 +113,7 @@ impl ClientBuilder {
event_handlers: vec![],

Check failure on line 113 in src/client/mod.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/client/mod.rs
raw_event_handlers: vec![],
presence: PresenceData::default(),
ws_proxy: None
}
}

Expand Down Expand Up @@ -293,6 +295,23 @@ impl ClientBuilder {
pub fn get_presence(&self) -> &PresenceData {
&self.presence
}

/// Sets a http proxy for the websocket connection.
pub fn ws_proxy<T: Into<String>>(mut self, proxy: T) -> Self {
self.ws_proxy = Some(proxy.into());
self
}

/// Remove websocket proxy.
pub fn no_ws_proxy(mut self) -> Self {
self.ws_proxy = None;

Check failure on line 307 in src/client/mod.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/client/mod.rs
self
}

/// Gets the websocket proxy. See [`Self::ws_proxy`] for more info.
pub fn get_ws_proxy(&self) -> Option<&str> {
self.ws_proxy.as_ref().map(|x|x.as_str())
}
}

#[cfg(feature = "gateway")]
Expand All @@ -310,6 +329,7 @@ impl IntoFuture for ClientBuilder {
let raw_event_handlers = self.raw_event_handlers;
let intents = self.intents;
let presence = self.presence;
let ws_proxy = self.ws_proxy;

let mut http = self.http;

Expand Down Expand Up @@ -351,6 +371,7 @@ impl IntoFuture for ClientBuilder {
#[cfg(feature = "voice")]
voice_manager: voice_manager.as_ref().map(Arc::clone),
ws_url: Arc::clone(&ws_url),
ws_proxy,
shard_total,
#[cfg(feature = "cache")]
cache: Arc::clone(&cache),
Expand Down
2 changes: 2 additions & 0 deletions src/gateway/bridge/shard_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl ShardManager {
#[cfg(feature = "voice")]
voice_manager: opt.voice_manager,
ws_url: opt.ws_url,
ws_proxy: opt.ws_proxy,
shard_total: opt.shard_total,
#[cfg(feature = "cache")]
cache: opt.cache,
Expand Down Expand Up @@ -365,6 +366,7 @@ pub struct ShardManagerOptions {
#[cfg(feature = "voice")]
pub voice_manager: Option<Arc<dyn VoiceGatewayManager>>,
pub ws_url: Arc<str>,
pub ws_proxy: Option<String>,
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
pub cache: Arc<Cache>,
Expand Down
2 changes: 2 additions & 0 deletions src/gateway/bridge/shard_queuer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub struct ShardQueuer {
pub voice_manager: Option<Arc<dyn VoiceGatewayManager + 'static>>,
/// A copy of the URL to use to connect to the gateway.
pub ws_url: Arc<str>,
pub ws_proxy: Option<String>,
/// The total amount of shards to start.
pub shard_total: NonZeroU16,
#[cfg(feature = "cache")]
Expand Down Expand Up @@ -171,6 +172,7 @@ impl ShardQueuer {
async fn start(&mut self, shard_id: ShardId) -> Result<()> {

Check failure on line 172 in src/gateway/bridge/shard_queuer.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/bridge/shard_queuer.rs
let mut shard = Shard::new(
Arc::clone(&self.ws_url),
self.ws_proxy.as_ref().map(|x|x.as_str()),
self.http.token(),
ShardInfo::new(shard_id, self.shard_total),
self.intents,
Expand Down
25 changes: 21 additions & 4 deletions src/gateway/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct Shard {
pub started: Instant,
pub token: String,
ws_url: Arc<str>,
ws_proxy: Option<String>,
pub intents: GatewayIntents,
}

Expand Down Expand Up @@ -120,12 +121,13 @@ impl Shard {
/// TLS error.
pub async fn new(
ws_url: Arc<str>,
ws_proxy: Option<&str>,
token: &str,
shard_info: ShardInfo,
intents: GatewayIntents,
presence: Option<PresenceData>,
) -> Result<Shard> {
let client = connect(&ws_url).await?;
let client = connect(&ws_url, ws_proxy).await?;

let presence = presence.unwrap_or_default();
let last_heartbeat_sent = None;
Expand All @@ -151,6 +153,7 @@ impl Shard {
session_id,

Check failure on line 153 in src/gateway/shard.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/shard.rs
shard_info,
ws_url,
ws_proxy: ws_proxy.map(|x|x.to_owned()),
intents,
})
}
Expand Down Expand Up @@ -684,7 +687,8 @@ impl Shard {
// Hello is received.

Check failure on line 687 in src/gateway/shard.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/shard.rs
self.stage = ConnectionStage::Connecting;
self.started = Instant::now();
let client = connect(&self.ws_url).await?;
let proxy = self.ws_proxy.as_ref().map(|x|x.as_str());
let client = connect(&self.ws_url, proxy).await?;
self.stage = ConnectionStage::Handshake;

Ok(client)
Expand Down Expand Up @@ -741,13 +745,26 @@ impl Shard {
}
}

async fn connect(base_url: &str) -> Result<WsClient> {
async fn connect(base_url: &str, proxy: Option<&str>) -> Result<WsClient> {
let url =
Url::parse(&format!("{base_url}?v={}", constants::GATEWAY_VERSION)).map_err(|why| {
warn!("Error building gateway URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})?;

WsClient::connect(url).await
let proxy_url = proxy.map(|proxy| {
Url::parse(proxy).map_err(|why| {
warn!("Error building proxy URL with base `{}`: {:?}", base_url, why);

Error::Gateway(GatewayError::BuildingUrl)
})
});

Check failure on line 763 in src/gateway/shard.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/shard.rs
let proxy = match proxy_url {
Some(result) => Some(result?),
None => None
};

WsClient::connect(url, proxy).await
}
63 changes: 59 additions & 4 deletions src/gateway/ws.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::env::consts;
use std::io::ErrorKind;
#[cfg(feature = "client")]
use std::io::Read;
use std::time::SystemTime;
Expand All @@ -8,6 +9,7 @@ use flate2::read::ZlibDecoder;
use futures::SinkExt;
#[cfg(feature = "client")]
use futures::StreamExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(feature = "client")]
use tokio::time::{timeout, Duration};
Expand All @@ -17,11 +19,11 @@ use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
#[cfg(feature = "client")]

Check failure on line 19 in src/gateway/ws.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/ws.rs
use tokio_tungstenite::tungstenite::Error as WsError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::{client_async_tls_with_config, connect_async_with_config, MaybeTlsStream, WebSocketStream};
#[cfg(feature = "client")]
use tracing::warn;
use tracing::{debug, trace};
use url::Url;
use url::{Position, Url};

use super::{ActivityData, ChunkGuildFilter, PresenceData};
use crate::constants::{self, Opcode};
Expand Down Expand Up @@ -101,13 +103,66 @@ const TIMEOUT: Duration = Duration::from_millis(500);
const DECOMPRESSION_MULTIPLIER: usize = 3;

Check failure on line 103 in src/gateway/ws.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/ws.rs

impl WsClient {
pub(crate) async fn connect(url: Url) -> Result<Self> {
async fn connect_with_proxy_async(target_url: &Url, proxy_url: &Url) -> std::result::Result<TcpStream, std::io::Error> {
let proxy_addr = &proxy_url[Position::BeforeHost..Position::AfterPort];
if proxy_url.scheme() != "http" && proxy_url.scheme() != "https" {
return Err(std::io::Error::new(ErrorKind::Unsupported, "unknown proxy scheme"));
}

Check failure on line 110 in src/gateway/ws.rs

View workflow job for this annotation

GitHub Actions / Format

Diff in /home/runner/work/serenity/serenity/src/gateway/ws.rs

let host = target_url.host_str()
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target host"))?;
let port = target_url.port()
.or_else(|| match target_url.scheme() {
"wss" => Some(443),
"ws" => Some(80),
_ => None,
})
.ok_or_else(|| std::io::Error::new(ErrorKind::Unsupported, "unknown target scheme"))?;
let mut tcp_stream = TcpStream::connect(proxy_addr).await?;

let buf = format!("CONNECT {0}:{1} HTTP/1.1\r\nHost: {0}:{1}\r\n\r\n", host, port).into_bytes();

tcp_stream.write_all(&buf).await?;

let mut all_buf = Vec::new();

loop {
let mut buf = [0; 1024];
let n = tcp_stream.read(&mut buf).await?;
if n == 0 {
return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "no bytes in tunnel"));
}
all_buf.extend_from_slice(&buf[..n]);

if !all_buf.starts_with(b"HTTP/1.1 200") &&
!all_buf.starts_with(b"HTTP/1.0 200") {
return Err(std::io::Error::new(ErrorKind::Other, "tunnel error"));
}
if all_buf.ends_with(b"\r\n\r\n") {
return Ok(tcp_stream);
}
if all_buf.len() > 4096 {
return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "too many bytes in tunnel"));
}
}
}

pub(crate) async fn connect(url: Url, proxy: Option<Url>) -> Result<Self> {
let config = WebSocketConfig {
max_message_size: None,
max_frame_size: None,
..Default::default()
};
let (stream, _) = connect_async_with_config(url, Some(config), false).await?;
let (stream, _) = match proxy {
None => {
connect_async_with_config(url, Some(config), false).await?
},
Some(proxy) => {
let tls_stream = Self::connect_with_proxy_async(&url, &proxy).await?;
tls_stream.set_nodelay(true)?;
client_async_tls_with_config(url, tls_stream, Some(config), None).await?
}
};

Ok(Self(stream))
}
Expand Down

0 comments on commit b596a76

Please sign in to comment.