Skip to content

Commit

Permalink
add to socket vtable to allow access to protocol and server_name
Browse files Browse the repository at this point in the history
  • Loading branch information
sbSteveK committed Nov 4, 2024
1 parent 3de0e84 commit a1226e9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/aws/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ struct aws_socket_vtable {
void *user_data);
int (*socket_get_error_fn)(struct aws_socket *socket);
bool (*socket_is_open_fn)(struct aws_socket *socket);
struct aws_byte_buf (*socket_get_protocol_fn)(const struct aws_socket *socket);
struct aws_string *(*socket_get_server_name_fn)(const struct aws_socket *socket);
};

struct aws_socket {
Expand Down
46 changes: 46 additions & 0 deletions source/darwin/nw_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ struct nw_socket {
struct aws_string *host_name;
struct aws_string *alpn_list;
struct aws_tls_ctx *tls_ctx;
struct aws_byte_buf protocol_buf;

struct {
struct aws_mutex lock;
Expand Down Expand Up @@ -644,6 +645,8 @@ static int s_socket_write_fn(
void *user_data);
static int s_socket_get_error_fn(struct aws_socket *socket);
static bool s_socket_is_open_fn(struct aws_socket *socket);
static struct aws_byte_buf s_socket_get_protocol_fn(const struct aws_socket *socket);
static struct aws_string *s_socket_get_server_name_fn(const struct aws_socket *socket);

static struct aws_socket_vtable s_vtable = {
.socket_cleanup_fn = s_socket_cleanup_fn,
Expand All @@ -661,6 +664,8 @@ static struct aws_socket_vtable s_vtable = {
.socket_write_fn = s_socket_write_fn,
.socket_get_error_fn = s_socket_get_error_fn,
.socket_is_open_fn = s_socket_is_open_fn,
.socket_get_protocol_fn = s_socket_get_protocol_fn,
.socket_get_server_name_fn = s_socket_get_server_name_fn,
};

static void s_schedule_next_read(struct nw_socket *socket);
Expand Down Expand Up @@ -731,6 +736,8 @@ static void s_socket_impl_destroy(void *sock_ptr) {
aws_string_destroy(nw_socket->alpn_list);
}

aws_byte_buf_clean_up(&nw_socket->protocol_buf);

if (nw_socket->tls_ctx) {
aws_tls_ctx_release(nw_socket->tls_ctx);
nw_socket->tls_ctx = NULL;
Expand Down Expand Up @@ -1391,6 +1398,35 @@ static int s_socket_connect_fn(
nw_socket->timeout_args->connection_succeed = true;
s_schedule_cancel_task(nw_socket, &nw_socket->timeout_args->task);
}

/* Check and store protocol for connection */
if (nw_socket->tls_ctx) {
nw_protocol_metadata_t metadata = nw_connection_copy_protocol_metadata(
socket->io_handle.data.handle, nw_protocol_copy_tls_definition());
if (metadata != NULL) {
sec_protocol_metadata_t sec_metadata = (sec_protocol_metadata_t)metadata;

const char *negotiated_protocol = sec_protocol_metadata_get_negotiated_protocol(sec_metadata);
if (negotiated_protocol) {
nw_socket->protocol_buf.allocator = nw_socket->allocator;
size_t protocol_len = strlen(negotiated_protocol);
nw_socket->protocol_buf.buffer =
(uint8_t *)aws_mem_acquire(nw_socket->allocator, protocol_len);
nw_socket->protocol_buf.len = protocol_len;
nw_socket->protocol_buf.capacity = protocol_len;
memcpy(nw_socket->protocol_buf.buffer, negotiated_protocol, protocol_len);

AWS_LOGF_DEBUG(
AWS_LS_IO_TLS,
"id=%p handle=%p: ALPN protocol set to: '%s'",
(void *)socket,
socket->io_handle.data.handle,
nw_socket->protocol_buf.buffer);
}
nw_release(metadata);
}
}

socket->state = CONNECTED_WRITE | CONNECTED_READ;
nw_socket->setup_run = true;
aws_ref_count_acquire(&nw_socket->ref_count);
Expand Down Expand Up @@ -2091,6 +2127,16 @@ static bool s_socket_is_open_fn(struct aws_socket *socket) {
return nw_socket->last_error == AWS_OP_SUCCESS;
}

static struct aws_byte_buf s_socket_get_protocol_fn(const struct aws_socket *socket) {
struct nw_socket *nw_socket = socket->impl;
return nw_socket->protocol_buf;
}

static struct aws_string *s_socket_get_server_name_fn(const struct aws_socket *socket) {
struct nw_socket *nw_socket = socket->impl;
return nw_socket->host_name;
}

void aws_socket_endpoint_init_local_address_for_test(struct aws_socket_endpoint *endpoint) {
struct aws_uuid uuid;
AWS_FATAL_ASSERT(aws_uuid_init(&uuid) == AWS_OP_SUCCESS);
Expand Down
28 changes: 25 additions & 3 deletions source/darwin/secure_transport_tls_channel_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <aws/io/private/aws_apple_network_framework.h>
#include <aws/io/private/pki_utils.h>
#include <aws/io/private/tls_channel_handler_shared.h>
#include <aws/io/socket.h>
#include <aws/io/socket_channel_handler.h>
#include <aws/io/statistics.h>

#include <aws/io/logging.h>
Expand Down Expand Up @@ -829,17 +831,37 @@ static void s_gather_statistics(struct aws_channel_handler *handler, struct aws_
}

struct aws_byte_buf aws_tls_handler_protocol(struct aws_channel_handler *handler) {
#if defined(AWS_USE_SECITEM)
/* Apple Network Framework's SecItem API handles both TCP and TLS aspects of a connection
* and an aws_channel using it does not have a TLS. The negotiated protocol is stored
* in the nw_socket and must be retrieved from the socket rather than a secure_transport_handler. */
const struct aws_socket *socket = aws_socket_handler_get_socket(handler);
return socket->vtable->socket_get_protocol_fn(socket);
#endif /* AWS_USE_SECITEM */
struct secure_transport_handler *secure_transport_handler = handler->impl;

return secure_transport_handler->protocol;
}

struct aws_byte_buf aws_tls_handler_server_name(struct aws_channel_handler *handler) {
struct aws_string *server_name = NULL;
#if defined(AWS_USE_SECITEM)
/* Apple Network Framework's SecItem API handles both TCP and TLS aspects of a connection
* and an aws_channel using it does not have a TLS slot. The server_name is stored
* in the nw_socket and must be retrieved from the socket rather than a secure_transport_handler. */
const struct aws_socket *socket = aws_socket_handler_get_socket(handler);
if (socket->vtable->socket_get_server_name_fn) {
server_name = socket->vtable->socket_get_server_name_fn(socket);
}
#else
struct secure_transport_handler *secure_transport_handler = handler->impl;
server_name = secure_transport_handler->server_name
#endif
const uint8_t *bytes = NULL;
size_t len = 0;
if (secure_transport_handler->server_name) {
bytes = secure_transport_handler->server_name->bytes;
len = secure_transport_handler->server_name->len;
if (server_name) {
bytes = server_name->bytes;
len = server_name->len;
}
return aws_byte_buf_from_array(bytes, len);
}
Expand Down

0 comments on commit a1226e9

Please sign in to comment.