diff --git a/include/aws/io/socket.h b/include/aws/io/socket.h index 12745f20a..faf7d3d59 100644 --- a/include/aws/io/socket.h +++ b/include/aws/io/socket.h @@ -164,6 +164,8 @@ struct aws_socket_vtable { void *user_data); int (*socket_get_error_fn)(struct aws_socket *socket); bool (*socket_is_open_fn)(struct aws_socket *socket); + struct aws_byte_buf (*socket_get_protocol_fn)(const struct aws_socket *socket); + struct aws_string *(*socket_get_server_name_fn)(const struct aws_socket *socket); }; struct aws_socket { diff --git a/source/darwin/nw_socket.c b/source/darwin/nw_socket.c index e8ebb852c..e2dbffe8c 100644 --- a/source/darwin/nw_socket.c +++ b/source/darwin/nw_socket.c @@ -195,6 +195,7 @@ struct nw_socket { struct aws_string *host_name; struct aws_string *alpn_list; struct aws_tls_ctx *tls_ctx; + struct aws_byte_buf protocol_buf; struct { struct aws_mutex lock; @@ -644,6 +645,8 @@ static int s_socket_write_fn( void *user_data); static int s_socket_get_error_fn(struct aws_socket *socket); static bool s_socket_is_open_fn(struct aws_socket *socket); +static struct aws_byte_buf s_socket_get_protocol_fn(const struct aws_socket *socket); +static struct aws_string *s_socket_get_server_name_fn(const struct aws_socket *socket); static struct aws_socket_vtable s_vtable = { .socket_cleanup_fn = s_socket_cleanup_fn, @@ -661,6 +664,8 @@ static struct aws_socket_vtable s_vtable = { .socket_write_fn = s_socket_write_fn, .socket_get_error_fn = s_socket_get_error_fn, .socket_is_open_fn = s_socket_is_open_fn, + .socket_get_protocol_fn = s_socket_get_protocol_fn, + .socket_get_server_name_fn = s_socket_get_server_name_fn, }; static void s_schedule_next_read(struct nw_socket *socket); @@ -731,6 +736,8 @@ static void s_socket_impl_destroy(void *sock_ptr) { aws_string_destroy(nw_socket->alpn_list); } + aws_byte_buf_clean_up(&nw_socket->protocol_buf); + if (nw_socket->tls_ctx) { aws_tls_ctx_release(nw_socket->tls_ctx); nw_socket->tls_ctx = NULL; @@ -1391,6 +1398,35 @@ static int s_socket_connect_fn( nw_socket->timeout_args->connection_succeed = true; s_schedule_cancel_task(nw_socket, &nw_socket->timeout_args->task); } + + /* Check and store protocol for connection */ + if (nw_socket->tls_ctx) { + nw_protocol_metadata_t metadata = nw_connection_copy_protocol_metadata( + socket->io_handle.data.handle, nw_protocol_copy_tls_definition()); + if (metadata != NULL) { + sec_protocol_metadata_t sec_metadata = (sec_protocol_metadata_t)metadata; + + const char *negotiated_protocol = sec_protocol_metadata_get_negotiated_protocol(sec_metadata); + if (negotiated_protocol) { + nw_socket->protocol_buf.allocator = nw_socket->allocator; + size_t protocol_len = strlen(negotiated_protocol); + nw_socket->protocol_buf.buffer = + (uint8_t *)aws_mem_acquire(nw_socket->allocator, protocol_len); + nw_socket->protocol_buf.len = protocol_len; + nw_socket->protocol_buf.capacity = protocol_len; + memcpy(nw_socket->protocol_buf.buffer, negotiated_protocol, protocol_len); + + AWS_LOGF_DEBUG( + AWS_LS_IO_TLS, + "id=%p handle=%p: ALPN protocol set to: '%s'", + (void *)socket, + socket->io_handle.data.handle, + nw_socket->protocol_buf.buffer); + } + nw_release(metadata); + } + } + socket->state = CONNECTED_WRITE | CONNECTED_READ; nw_socket->setup_run = true; aws_ref_count_acquire(&nw_socket->ref_count); @@ -2091,6 +2127,16 @@ static bool s_socket_is_open_fn(struct aws_socket *socket) { return nw_socket->last_error == AWS_OP_SUCCESS; } +static struct aws_byte_buf s_socket_get_protocol_fn(const struct aws_socket *socket) { + struct nw_socket *nw_socket = socket->impl; + return nw_socket->protocol_buf; +} + +static struct aws_string *s_socket_get_server_name_fn(const struct aws_socket *socket) { + struct nw_socket *nw_socket = socket->impl; + return nw_socket->host_name; +} + void aws_socket_endpoint_init_local_address_for_test(struct aws_socket_endpoint *endpoint) { struct aws_uuid uuid; AWS_FATAL_ASSERT(aws_uuid_init(&uuid) == AWS_OP_SUCCESS); diff --git a/source/darwin/secure_transport_tls_channel_handler.c b/source/darwin/secure_transport_tls_channel_handler.c index 268611aea..0d2948c87 100644 --- a/source/darwin/secure_transport_tls_channel_handler.c +++ b/source/darwin/secure_transport_tls_channel_handler.c @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include @@ -829,17 +831,37 @@ static void s_gather_statistics(struct aws_channel_handler *handler, struct aws_ } struct aws_byte_buf aws_tls_handler_protocol(struct aws_channel_handler *handler) { +#if defined(AWS_USE_SECITEM) + /* Apple Network Framework's SecItem API handles both TCP and TLS aspects of a connection + * and an aws_channel using it does not have a TLS. The negotiated protocol is stored + * in the nw_socket and must be retrieved from the socket rather than a secure_transport_handler. */ + const struct aws_socket *socket = aws_socket_handler_get_socket(handler); + return socket->vtable->socket_get_protocol_fn(socket); +#endif /* AWS_USE_SECITEM */ struct secure_transport_handler *secure_transport_handler = handler->impl; + return secure_transport_handler->protocol; } struct aws_byte_buf aws_tls_handler_server_name(struct aws_channel_handler *handler) { + struct aws_string *server_name = NULL; +#if defined(AWS_USE_SECITEM) + /* Apple Network Framework's SecItem API handles both TCP and TLS aspects of a connection + * and an aws_channel using it does not have a TLS slot. The server_name is stored + * in the nw_socket and must be retrieved from the socket rather than a secure_transport_handler. */ + const struct aws_socket *socket = aws_socket_handler_get_socket(handler); + if (socket->vtable->socket_get_server_name_fn) { + server_name = socket->vtable->socket_get_server_name_fn(socket); + } +#else struct secure_transport_handler *secure_transport_handler = handler->impl; + server_name = secure_transport_handler->server_name +#endif const uint8_t *bytes = NULL; size_t len = 0; - if (secure_transport_handler->server_name) { - bytes = secure_transport_handler->server_name->bytes; - len = secure_transport_handler->server_name->len; + if (server_name) { + bytes = server_name->bytes; + len = server_name->len; } return aws_byte_buf_from_array(bytes, len); }