[MSTSC] Add schannel based SSL implementation. CORE-9321

svn path=/trunk/; revision=73368
This commit is contained in:
Peter Hater 2016-11-24 10:56:09 +00:00
parent c253b1d7bc
commit a26ebfa365
2 changed files with 452 additions and 119 deletions

View file

@ -30,10 +30,11 @@ list(APPEND SOURCE
uimain.h
precomp.h)
add_definitions(-DWITH_SSL)
file(GLOB mstsc_rc_deps res/*.*)
add_rc_deps(rdc.rc ${mstsc_rc_deps})
add_executable(mstsc ${SOURCE} rdc.rc)
set_module_type(mstsc win32gui UNICODE)
add_importlibs(mstsc user32 gdi32 comctl32 ws2_32 advapi32 shell32 ole32 comdlg32 msvcrt kernel32)
add_importlibs(mstsc user32 gdi32 comctl32 ws2_32 crypt32 secur32 advapi32 shell32 ole32 comdlg32 msvcrt kernel32)
add_pch(mstsc precomp.h SOURCE)
add_cd_file(TARGET mstsc DESTINATION reactos/system32 FOR all)

View file

@ -20,6 +20,10 @@
*/
#include "precomp.h"
#ifdef WITH_SSL
#include <sspi.h>
#include <schannel.h>
#endif
#ifdef _WIN32
#define socklen_t int
@ -45,9 +49,25 @@
#endif
#ifdef WITH_SSL
typedef struct
{
CtxtHandle ssl_ctx;
SecPkgContext_StreamSizes ssl_sizes;
char *ssl_buf;
char *extra_buf;
size_t extra_len;
char *peek_msg;
char *peek_msg_mem;
size_t peek_len;
DWORD security_flags;
} netconn_t;
static char * g_ssl_server = NULL;
static RD_BOOL g_ssl_initialized = False;
static SSL *g_ssl = NULL;
static SSL_CTX *g_ssl_ctx = NULL;
static RD_BOOL cred_handle_initialized = False;
static RD_BOOL have_compat_cred_handle = False;
static SecHandle cred_handle, compat_cred_handle;
static netconn_t g_ssl1;
static netconn_t *g_ssl = NULL;
#endif /* WITH_SSL */
static int g_sock;
static struct stream g_in;
@ -84,13 +104,171 @@ tcp_init(uint32 maxlen)
return result;
}
#ifdef WITH_SSL
RD_BOOL send_ssl_chunk(const void *msg, size_t size)
{
SecBuffer bufs[4] = {
{g_ssl->ssl_sizes.cbHeader, SECBUFFER_STREAM_HEADER, g_ssl->ssl_buf},
{size, SECBUFFER_DATA, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader},
{g_ssl->ssl_sizes.cbTrailer, SECBUFFER_STREAM_TRAILER, g_ssl->ssl_buf+g_ssl->ssl_sizes.cbHeader+size},
{0, SECBUFFER_EMPTY, NULL}
};
SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
SECURITY_STATUS res;
int tcp_res;
memcpy(bufs[1].pvBuffer, msg, size);
res = EncryptMessage(&g_ssl->ssl_ctx, 0, &buf_desc, 0);
if (res != SEC_E_OK)
{
error("EncryptMessage failed: %d\n", res);
return False;
}
tcp_res = send(g_sock, g_ssl->ssl_buf, bufs[0].cbBuffer+bufs[1].cbBuffer+bufs[2].cbBuffer, 0);
if (tcp_res < 1)
{
error("send failed: %d (%s)\n", tcp_res, TCP_STRERROR);
return False;
}
return True;
}
DWORD read_ssl_chunk(void *buf, SIZE_T buf_size, BOOL blocking, SIZE_T *ret_size, BOOL *eof)
{
const SIZE_T ssl_buf_size = g_ssl->ssl_sizes.cbHeader+g_ssl->ssl_sizes.cbMaximumMessage+g_ssl->ssl_sizes.cbTrailer;
SecBuffer bufs[4];
SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs};
SSIZE_T size, buf_len = 0;
int i;
SECURITY_STATUS res;
//assert(conn->extra_len < ssl_buf_size);
if (g_ssl->extra_len)
{
memcpy(g_ssl->ssl_buf, g_ssl->extra_buf, g_ssl->extra_len);
buf_len = g_ssl->extra_len;
g_ssl->extra_len = 0;
xfree(g_ssl->extra_buf);
g_ssl->extra_buf = NULL;
}
size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
if (size < 0)
{
if (!buf_len)
{
if (size == -1 && TCP_BLOCKS)
{
return WSAEWOULDBLOCK;
}
error("recv failed: %d (%s)\n", size, TCP_STRERROR);
return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
}
}
else
{
buf_len += size;
}
if (!buf_len)
{
*eof = TRUE;
*ret_size = 0;
return ERROR_SUCCESS;
}
*eof = FALSE;
do
{
memset(bufs, 0, sizeof(bufs));
bufs[0].BufferType = SECBUFFER_DATA;
bufs[0].cbBuffer = buf_len;
bufs[0].pvBuffer = g_ssl->ssl_buf;
res = DecryptMessage(&g_ssl->ssl_ctx, &buf_desc, 0, NULL);
switch (res)
{
case SEC_E_OK:
break;
case SEC_I_CONTEXT_EXPIRED:
*eof = TRUE;
return ERROR_SUCCESS;
case SEC_E_INCOMPLETE_MESSAGE:
//assert(buf_len < ssl_buf_size);
size = recv(g_sock, g_ssl->ssl_buf+buf_len, ssl_buf_size-buf_len, 0);
if (size < 1)
{
if (size == -1 && TCP_BLOCKS)
{
/* FIXME: Optimize extra_buf usage. */
g_ssl->extra_buf = xmalloc(buf_len);
if (!g_ssl->extra_buf)
return ERROR_NOT_ENOUGH_MEMORY;
g_ssl->extra_len = buf_len;
memcpy(g_ssl->extra_buf, g_ssl->ssl_buf, g_ssl->extra_len);
return WSAEWOULDBLOCK;
}
error("recv failed: %d (%s)\n", size, TCP_STRERROR);
return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
}
buf_len += size;
continue;
default:
error("DecryptMessage failed: %d\n", res);
return -1;//ERROR_INTERNET_CONNECTION_ABORTED;
}
}
while (res != SEC_E_OK);
for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
{
if (bufs[i].BufferType == SECBUFFER_DATA)
{
size = min(buf_size, bufs[i].cbBuffer);
memcpy(buf, bufs[i].pvBuffer, size);
if (size < bufs[i].cbBuffer)
{
//assert(!conn->peek_len);
g_ssl->peek_msg_mem = g_ssl->peek_msg = xmalloc(bufs[i].cbBuffer - size);
if (!g_ssl->peek_msg)
return ERROR_NOT_ENOUGH_MEMORY;
g_ssl->peek_len = bufs[i].cbBuffer-size;
memcpy(g_ssl->peek_msg, (char*)bufs[i].pvBuffer+size, g_ssl->peek_len);
}
*ret_size = size;
}
}
for (i=0; i < sizeof(bufs)/sizeof(*bufs); i++)
{
if (bufs[i].BufferType == SECBUFFER_EXTRA)
{
g_ssl->extra_buf = xmalloc(bufs[i].cbBuffer);
if (!g_ssl->extra_buf)
return ERROR_NOT_ENOUGH_MEMORY;
g_ssl->extra_len = bufs[i].cbBuffer;
memcpy(g_ssl->extra_buf, bufs[i].pvBuffer, g_ssl->extra_len);
}
}
return ERROR_SUCCESS;
}
#endif /* WITH_SSL */
/* Send TCP transport data packet */
void
tcp_send(STREAM s)
{
#ifdef WITH_SSL
int ssl_err;
#endif /* WITH_SSL */
int length = s->end - s->data;
int sent, total = 0;
@ -105,26 +283,28 @@ tcp_send(STREAM s)
#ifdef WITH_SSL
if (g_ssl)
{
sent = SSL_write(g_ssl, s->data + total, length - total);
if (sent <= 0)
const BYTE *ptr = s->data + total;
size_t chunk_size;
sent = 0;
while (length - total)
{
ssl_err = SSL_get_error(g_ssl, sent);
if (sent < 0 && (ssl_err == SSL_ERROR_WANT_READ ||
ssl_err == SSL_ERROR_WANT_WRITE))
{
TCP_SLEEP(0);
sent = 0;
}
else
chunk_size = min(length - total, g_ssl->ssl_sizes.cbMaximumMessage);
if (!send_ssl_chunk(ptr, chunk_size))
{
#ifdef WITH_SCARD
scard_unlock(SCARD_LOCK_TCP);
#endif
error("SSL_write: %d (%s)\n", ssl_err, TCP_STRERROR);
//error("send_ssl_chunk: %d (%s)\n", sent, TCP_STRERROR);
g_network_error = True;
return;
}
sent += chunk_size;
ptr += chunk_size;
length -= chunk_size;
}
}
else
@ -144,7 +324,7 @@ tcp_send(STREAM s)
scard_unlock(SCARD_LOCK_TCP);
#endif
error("send: %s\n", TCP_STRERROR);
error("send: %d (%s)\n", sent, TCP_STRERROR);
g_network_error = True;
return;
}
@ -165,9 +345,6 @@ tcp_recv(STREAM s, uint32 length)
{
uint32 new_length, end_offset, p_offset;
int rcvd = 0;
#ifdef WITH_SSL
int ssl_err;
#endif /* WITH_SSL */
if (g_network_error == True)
return NULL;
@ -201,7 +378,7 @@ tcp_recv(STREAM s, uint32 length)
while (length > 0)
{
#ifdef WITH_SSL
if (!g_ssl || SSL_pending(g_ssl) <= 0)
if (!g_ssl)
#endif /* WITH_SSL */
{
if (!ui_select(g_sock))
@ -215,33 +392,50 @@ tcp_recv(STREAM s, uint32 length)
#ifdef WITH_SSL
if (g_ssl)
{
rcvd = SSL_read(g_ssl, s->end, length);
ssl_err = SSL_get_error(g_ssl, rcvd);
SIZE_T size = 0;
BOOL eof;
DWORD res;
if (ssl_err == SSL_ERROR_SSL)
if (g_ssl->peek_msg)
{
if (SSL_get_shutdown(g_ssl) & SSL_RECEIVED_SHUTDOWN)
size = min(length, g_ssl->peek_len);
memcpy(s->end, g_ssl->peek_msg, size);
g_ssl->peek_len -= size;
g_ssl->peek_msg += size;
s->end += size;
if (!g_ssl->peek_len)
{
error("Remote peer initiated ssl shutdown.\n");
return NULL;
xfree(g_ssl->peek_msg_mem);
g_ssl->peek_msg_mem = g_ssl->peek_msg = NULL;
}
ERR_print_errors_fp(stdout);
g_network_error = True;
return NULL;
return s;
}
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE)
do
{
rcvd = 0;
res = read_ssl_chunk((BYTE*)s->end, length, TRUE, &size, &eof);
if (res != ERROR_SUCCESS)
{
if (res == WSAEWOULDBLOCK)
{
if (size)
{
res = ERROR_SUCCESS;
}
}
else
{
error("read_ssl_chunk: %d (%s)\n", res, TCP_STRERROR);
g_network_error = True;
return NULL;
}
break;
}
}
else if (ssl_err != SSL_ERROR_NONE)
{
error("SSL_read: %d (%s)\n", ssl_err, TCP_STRERROR);
g_network_error = True;
return NULL;
}
while (!size && !eof);
rcvd = size;
}
else
{
@ -255,7 +449,7 @@ tcp_recv(STREAM s, uint32 length)
}
else
{
error("recv: %s\n", TCP_STRERROR);
error("recv: %d (%s)\n", rcvd, TCP_STRERROR);
g_network_error = True;
return NULL;
}
@ -277,78 +471,206 @@ tcp_recv(STREAM s, uint32 length)
}
#ifdef WITH_SSL
RD_BOOL
ensure_cred_handle(void)
{
SECURITY_STATUS res = SEC_E_OK;
if (!cred_handle_initialized)
{
SCHANNEL_CRED cred = {SCHANNEL_CRED_VERSION};
SecPkgCred_SupportedProtocols prots;
res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
NULL, NULL, &cred_handle, NULL);
if (res == SEC_E_OK)
{
res = QueryCredentialsAttributesA(&cred_handle, SECPKG_ATTR_SUPPORTED_PROTOCOLS, &prots);
if (res != SEC_E_OK || (prots.grbitProtocol & SP_PROT_TLS1_1PLUS_CLIENT))
{
cred.grbitEnabledProtocols = prots.grbitProtocol & ~SP_PROT_TLS1_1PLUS_CLIENT;
res = AcquireCredentialsHandleW(NULL, (WCHAR*)UNISP_NAME_W, SECPKG_CRED_OUTBOUND, NULL, &cred,
NULL, NULL, &compat_cred_handle, NULL);
have_compat_cred_handle = res == SEC_E_OK;
}
}
cred_handle_initialized = res == SEC_E_OK;
}
if (res != SEC_E_OK)
{
error("ensure_cred_handle failed: %ld\n", res);
return False;
}
return True;
}
DWORD
ssl_handshake(RD_BOOL compat_mode)
{
SecBuffer out_buf = {0, SECBUFFER_TOKEN, NULL}, in_bufs[2] = {{0, SECBUFFER_TOKEN}, {0, SECBUFFER_EMPTY}};
SecBufferDesc out_desc = {SECBUFFER_VERSION, 1, &out_buf}, in_desc = {SECBUFFER_VERSION, 2, in_bufs};
SecHandle *cred = &cred_handle;
BYTE *read_buf;
SIZE_T read_buf_size = 2048;
ULONG attrs = 0;
CtxtHandle ctx;
SSIZE_T size;
SECURITY_STATUS status;
DWORD res = ERROR_SUCCESS;
const DWORD isc_req_flags = ISC_REQ_ALLOCATE_MEMORY|ISC_REQ_USE_SESSION_KEY|ISC_REQ_CONFIDENTIALITY
|ISC_REQ_SEQUENCE_DETECT|ISC_REQ_REPLAY_DETECT|ISC_REQ_MANUAL_CRED_VALIDATION;
if (!ensure_cred_handle())
return -1;
if (compat_mode) {
if (!have_compat_cred_handle)
return -1;
cred = &compat_cred_handle;
}
read_buf = xmalloc(read_buf_size);
if (!read_buf)
return ERROR_OUTOFMEMORY;
if (!g_ssl_server)
return -1;
status = InitializeSecurityContextA(cred, NULL, g_ssl_server, isc_req_flags, 0, 0, NULL, 0,
&ctx, &out_desc, &attrs, NULL);
//assert(status != SEC_E_OK);
while (status == SEC_I_CONTINUE_NEEDED || status == SEC_E_INCOMPLETE_MESSAGE)
{
if (out_buf.cbBuffer)
{
//assert(status == SEC_I_CONTINUE_NEEDED);
size = send(g_sock, out_buf.pvBuffer, out_buf.cbBuffer, 0);
if (size != out_buf.cbBuffer)
{
error("send failed: %d (%s)\n", size, TCP_STRERROR);
status = -1;
break;
}
FreeContextBuffer(out_buf.pvBuffer);
out_buf.pvBuffer = NULL;
out_buf.cbBuffer = 0;
}
if (status == SEC_I_CONTINUE_NEEDED)
{
//assert(in_bufs[1].cbBuffer < read_buf_size);
memmove(read_buf, (BYTE*)in_bufs[0].pvBuffer+in_bufs[0].cbBuffer-in_bufs[1].cbBuffer, in_bufs[1].cbBuffer);
in_bufs[0].cbBuffer = in_bufs[1].cbBuffer;
in_bufs[1].BufferType = SECBUFFER_EMPTY;
in_bufs[1].cbBuffer = 0;
in_bufs[1].pvBuffer = NULL;
}
//assert(in_bufs[0].BufferType == SECBUFFER_TOKEN);
//assert(in_bufs[1].BufferType == SECBUFFER_EMPTY);
if (in_bufs[0].cbBuffer + 1024 > read_buf_size)
{
BYTE *new_read_buf = xrealloc(read_buf, read_buf_size + 1024);
if (!new_read_buf)
{
status = E_OUTOFMEMORY;
break;
}
in_bufs[0].pvBuffer = read_buf = new_read_buf;
read_buf_size += 1024;
}
size = recv(g_sock, read_buf+in_bufs[0].cbBuffer, read_buf_size-in_bufs[0].cbBuffer, 0);
if (size < 1)
{
error("recv failed: %d (%s)\n", size, TCP_STRERROR);
res = -1;
break;
}
in_bufs[0].cbBuffer += size;
in_bufs[0].pvBuffer = read_buf;
status = InitializeSecurityContextA(cred, &ctx, g_ssl_server, isc_req_flags, 0, 0, &in_desc,
0, NULL, &out_desc, &attrs, NULL);
if (status == SEC_E_OK) {
if (SecIsValidHandle(&g_ssl->ssl_ctx))
DeleteSecurityContext(&g_ssl->ssl_ctx);
g_ssl->ssl_ctx = ctx;
if (in_bufs[1].BufferType == SECBUFFER_EXTRA)
{
//FIXME("SECBUFFER_EXTRA not supported\n");
}
status = QueryContextAttributesW(&ctx, SECPKG_ATTR_STREAM_SIZES, &g_ssl->ssl_sizes);
if (status != SEC_E_OK)
{
//error("Can't determine ssl buffer sizes: %ld\n", status);
break;
}
g_ssl->ssl_buf = xmalloc(g_ssl->ssl_sizes.cbHeader + g_ssl->ssl_sizes.cbMaximumMessage
+ g_ssl->ssl_sizes.cbTrailer);
if (!g_ssl->ssl_buf)
{
res = GetLastError();
break;
}
}
}
xfree(read_buf);
if (status != SEC_E_OK || res != ERROR_SUCCESS)
{
error("Failed to establish SSL connection: %08x (%u)\n", status, res);
xfree(g_ssl->ssl_buf);
g_ssl->ssl_buf = NULL;
return res ? res : -1;
}
return ERROR_SUCCESS;
}
/* Establish a SSL/TLS 1.0 connection */
RD_BOOL
tcp_tls_connect(void)
{
int err;
long options;
char tcp_port_rdp_s[10];
if (!g_ssl_initialized)
{
SSL_load_error_strings();
SSL_library_init();
g_ssl = &g_ssl1;
SecInvalidateHandle(&g_ssl->ssl_ctx);
g_ssl_initialized = True;
}
/* create process context */
if (g_ssl_ctx == NULL)
snprintf(tcp_port_rdp_s, 10, "%d", g_tcp_port_rdp);
if ((err = ssl_handshake(FALSE)) != 0)
{
g_ssl_ctx = SSL_CTX_new(TLSv1_client_method());
if (g_ssl_ctx == NULL)
{
error("tcp_tls_connect: SSL_CTX_new() failed to create TLS v1.0 context\n");
goto fail;
}
options = 0;
#ifdef SSL_OP_NO_COMPRESSION
options |= SSL_OP_NO_COMPRESSION;
#endif // __SSL_OP_NO_COMPRESSION
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
SSL_CTX_set_options(g_ssl_ctx, options);
}
/* free old connection */
if (g_ssl)
SSL_free(g_ssl);
/* create new ssl connection */
g_ssl = SSL_new(g_ssl_ctx);
if (g_ssl == NULL)
{
error("tcp_tls_connect: SSL_new() failed\n");
goto fail;
}
if (SSL_set_fd(g_ssl, g_sock) < 1)
{
error("tcp_tls_connect: SSL_set_fd() failed\n");
goto fail;
}
do
{
err = SSL_connect(g_ssl);
}
while (SSL_get_error(g_ssl, err) == SSL_ERROR_WANT_READ);
if (err < 0)
{
ERR_print_errors_fp(stdout);
goto fail;
}
return True;
fail:
if (g_ssl)
SSL_free(g_ssl);
if (g_ssl_ctx)
SSL_CTX_free(g_ssl_ctx);
g_ssl = NULL;
g_ssl_ctx = NULL;
return False;
}
@ -356,46 +678,36 @@ tcp_tls_connect(void)
RD_BOOL
tcp_tls_get_server_pubkey(STREAM s)
{
X509 *cert = NULL;
EVP_PKEY *pkey = NULL;
const CERT_CONTEXT *cert = NULL;
SECURITY_STATUS status;
s->data = s->p = NULL;
s->size = 0;
if (g_ssl == NULL)
goto out;
cert = SSL_get_peer_certificate(g_ssl);
if (cert == NULL)
status = QueryContextAttributesW(&g_ssl->ssl_ctx, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (void*)&cert);
if (status != SEC_E_OK)
{
error("tcp_tls_get_server_pubkey: SSL_get_peer_certificate() failed\n");
error("tcp_tls_get_server_pubkey: QueryContextAttributesW() failed %ld\n", status);
goto out;
}
pkey = X509_get_pubkey(cert);
if (pkey == NULL)
{
error("tcp_tls_get_server_pubkey: X509_get_pubkey() failed\n");
goto out;
}
s->size = i2d_PublicKey(pkey, NULL);
s->size = cert->cbCertEncoded;
if (s->size < 1)
{
error("tcp_tls_get_server_pubkey: i2d_PublicKey() failed\n");
error("tcp_tls_get_server_pubkey: cert->cbCertEncoded = %ld\n", cert->cbCertEncoded);
goto out;
}
s->data = s->p = xmalloc(s->size);
i2d_PublicKey(pkey, &s->p);
s->data = s->p = (unsigned char *)xmalloc(s->size);
memcpy(cert->pbCertEncoded, &s->p, s->size);
s->p = s->data;
s->end = s->p + s->size;
out:
if (cert)
X509_free(cert);
if (pkey)
EVP_PKEY_free(pkey);
CertFreeCertificateContext(cert);
return (s->size != 0);
}
#endif /* WITH_SSL */
@ -508,6 +820,10 @@ tcp_connect(char *server)
g_out[i].data = (uint8 *) xmalloc(g_out[i].size);
}
#ifdef WITH_SSL
g_ssl_server = xmalloc(strlen(server)+1);
#endif /* WITH_SSL */
return True;
}
@ -518,12 +834,28 @@ tcp_disconnect(void)
#ifdef WITH_SSL
if (g_ssl)
{
if (!g_network_error)
(void) SSL_shutdown(g_ssl);
SSL_free(g_ssl);
xfree(g_ssl->peek_msg_mem);
g_ssl->peek_msg_mem = NULL;
g_ssl->peek_msg = NULL;
g_ssl->peek_len = 0;
xfree(g_ssl->ssl_buf);
g_ssl->ssl_buf = NULL;
xfree(g_ssl->extra_buf);
g_ssl->extra_buf = NULL;
g_ssl->extra_len = 0;
if (SecIsValidHandle(&g_ssl->ssl_ctx))
DeleteSecurityContext(&g_ssl->ssl_ctx);
if (cred_handle_initialized)
FreeCredentialsHandle(&cred_handle);
if (have_compat_cred_handle)
FreeCredentialsHandle(&compat_cred_handle);
if (g_ssl_server)
{
xfree(g_ssl_server);
g_ssl_server = NULL;
}
g_ssl = NULL;
SSL_CTX_free(g_ssl_ctx);
g_ssl_ctx = NULL;
g_ssl_initialized = False;
}
#endif /* WITH_SSL */
TCP_CLOSE(g_sock);