From 40360a992d03ccccf69a36fa20359ad029b3afcf Mon Sep 17 00:00:00 2001 From: cinap_lenrek Date: Thu, 21 May 2015 02:26:57 +0200 Subject: [PATCH] libsec: implement tlsClient support for RFC6066 server name identification (SNI) tlsClient() now can optionally send the server_name in the ClientHello message by setting the TLSconn.serverName. This is required for some https sites. --- sys/include/libsec.h | 1 + sys/man/2/pushtls | 1 + sys/src/libsec/port/tlshand.c | 106 +++++++++++++++++++++++++++++----- 3 files changed, 93 insertions(+), 15 deletions(-) diff --git a/sys/include/libsec.h b/sys/include/libsec.h index ccebac087..9a32aa554 100644 --- a/sys/include/libsec.h +++ b/sys/include/libsec.h @@ -383,6 +383,7 @@ typedef struct TLSconn{ uchar *sessionKey; int sessionKeylen; char *sessionConst; + char *serverName; } TLSconn; /* tlshand.c */ diff --git a/sys/man/2/pushtls b/sys/man/2/pushtls index cf9ad21ff..dfa01e4dd 100644 --- a/sys/man/2/pushtls +++ b/sys/man/2/pushtls @@ -107,6 +107,7 @@ typedef struct TLSconn { uchar *sessionKey; /* opt IN/OUT session key */ int sessionKeylen; /* opt IN session key length */ char *sessionConst; /* opt IN session constant */ + char *serverName; /* opt IN server name */ } TLSconn; .EE .PP diff --git a/sys/src/libsec/port/tlshand.c b/sys/src/libsec/port/tlshand.c index d8df85d9a..a0655c790 100644 --- a/sys/src/libsec/port/tlshand.c +++ b/sys/src/libsec/port/tlshand.c @@ -96,13 +96,15 @@ typedef struct Msg{ Bytes* sid; Ints* ciphers; Bytes* compressors; + Bytes* extensions; } clientHello; struct { int version; - uchar random[RandomSize]; + uchar random[RandomSize]; Bytes* sid; - int cipher; - int compressor; + int cipher; + int compressor; + Bytes* extensions; } serverHello; struct { int ncert; @@ -266,8 +268,8 @@ static uchar compressors[] = { CompressionNull, }; -static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain); -static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)); +static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain); +static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...)); static void msgClear(Msg *m); static char* msgPrint(char *buf, int n, Msg *m); static int msgRecv(TlsConnection *c, Msg *m); @@ -390,6 +392,33 @@ tlsServer(int fd, TLSconn *conn) return data; } +static uchar* +tlsClientExtensions(TLSconn *conn, int *plen) +{ + uchar *b, *p; + int n, m; + + p = b = nil; + + // RFC6066 - Server Name Identification + if(conn->serverName != nil){ + n = strlen(conn->serverName); + m = p - b; + b = erealloc(b, m+2+2+2+1+2+n); + p = b + m; + put16(p, 0), p += 2; /* Type: server_name */ + put16(p, 2+1+2+n), p += 2; /* Length */ + put16(p, 1+2+n), p += 2; /* Server Name list length */ + *p++ = 0; /* Server Name Type: host_name */ + put16(p, n), p += 2; /* Server Name length */ + memmove(p, conn->serverName, n); + p += n; + } + + *plen = p - b; + return b; +} + // push TLS onto fd, returning new (application) file descriptor // or -1 if error. int @@ -399,6 +428,7 @@ tlsClient(int fd, TLSconn *conn) char dname[64]; int n, data, ctl, hand; TlsConnection *tls; + uchar *ext; if(conn == nil) return -1; @@ -426,7 +456,10 @@ tlsClient(int fd, TLSconn *conn) return -1; } fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); - tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->trace); + ext = tlsClientExtensions(conn, &n); + tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, + ext, n, conn->trace); + free(ext); close(hand); close(ctl); if(tls == nil){ @@ -466,7 +499,7 @@ countchain(PEMChain *p) } static TlsConnection * -tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp) +tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp) { TlsConnection *c; Msg m; @@ -531,12 +564,12 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ... tlsError(c, EHandshakeFailure, "can't initialize security: %r"); goto Err; } - c->sec->rpc = factotum_rsa_open(cert, ncert); + c->sec->rpc = factotum_rsa_open(cert, certlen); if(c->sec->rpc == nil){ tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); goto Err; } - c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0); + c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); if(c->sec->rsapub == nil){ tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); goto Err; @@ -558,7 +591,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ... numcerts = countchain(chp); m.u.certificate.ncert = 1 + numcerts; m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); - m.u.certificate.certs[0] = makebytes(cert, ncert); + m.u.certificate.certs[0] = makebytes(cert, certlen); for (i = 0; i < numcerts && chp; i++, chp = chp->next) m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); if(!msgSend(c, &m, AQueue)) @@ -702,7 +735,8 @@ Out: } static TlsConnection * -tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)) +tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, + int (*trace)(char*fmt, ...)) { TlsConnection *c; Msg m; @@ -735,6 +769,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, m.u.clientHello.sid = makebytes(csid, ncsid); m.u.clientHello.ciphers = makeciphers(); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); + m.u.clientHello.extensions = makebytes(ext, extlen); if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); @@ -1015,6 +1050,15 @@ msgSend(TlsConnection *c, Msg *m, int act) p[0] = n; memmove(p+1, m->u.clientHello.compressors->data, n); p += n+1; + + if(m->u.clientHello.extensions == nil) + break; + n = m->u.clientHello.extensions->len; + if(n == 0) + break; + put16(p, n); + memmove(p+2, m->u.clientHello.extensions->data, n); + p += n+2; break; case HServerHello: put16(p, m->u.serverHello.version); @@ -1035,6 +1079,15 @@ msgSend(TlsConnection *c, Msg *m, int act) p += 2; p[0] = m->u.serverHello.compressor; p += 1; + + if(m->u.serverHello.extensions == nil) + break; + n = m->u.serverHello.extensions->len; + if(n == 0) + break; + put16(p, n); + memmove(p+2, m->u.serverHello.extensions->data, n); + p += n+2; break; case HServerHelloDone: break; @@ -1249,9 +1302,17 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 1 || n < p[0]+1 || p[0] == 0) goto Short; nn = p[0]; - m->u.clientHello.compressors = newbytes(nn); - memmove(m->u.clientHello.compressors->data, p+1, nn); + m->u.clientHello.compressors = makebytes(p+1, nn); + p += nn + 1; n -= nn + 1; + + if(n < 2) + break; + nn = get16(p); + if(nn > n-2) + goto Short; + m->u.clientHello.extensions = makebytes(p+2, nn); + n -= nn + 2; break; case HServerHello: if(n < 2) @@ -1276,7 +1337,16 @@ msgRecv(TlsConnection *c, Msg *m) goto Short; m->u.serverHello.cipher = get16(p); m->u.serverHello.compressor = p[2]; + p += 3; n -= 3; + + if(n < 2) + break; + nn = get16(p); + if(nn > n-2) + goto Short; + m->u.serverHello.extensions = makebytes(p+2, nn); + n -= nn + 2; break; case HCertificate: if(n < 3) @@ -1409,7 +1479,7 @@ msgRecv(TlsConnection *c, Msg *m) break; } - if(type != HClientHello && n != 0) + if(type != HClientHello && type != HServerHello && n != 0) goto Short; Ok: if(c->trace){ @@ -1440,9 +1510,11 @@ msgClear(Msg *m) freebytes(m->u.clientHello.sid); freeints(m->u.clientHello.ciphers); freebytes(m->u.clientHello.compressors); + freebytes(m->u.clientHello.extensions); break; case HServerHello: - freebytes(m->u.clientHello.sid); + freebytes(m->u.serverHello.sid); + freebytes(m->u.serverHello.extensions); break; case HCertificate: for(i=0; iu.certificate.ncert; i++) @@ -1534,6 +1606,8 @@ msgPrint(char *buf, int n, Msg *m) bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n"); bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n"); bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n"); + if(m->u.clientHello.extensions != nil) + bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n"); break; case HServerHello: bs = seprint(bs, be, "ServerHello\n"); @@ -1545,6 +1619,8 @@ msgPrint(char *buf, int n, Msg *m) bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n"); bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher); bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor); + if(m->u.serverHello.extensions != nil) + bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n"); break; case HCertificate: bs = seprint(bs, be, "Certificate\n");