libsec: implement tlsClient support for RFC6066 server name identification (SNI)

tlsClient() now can optionally send the server_name in the ClientHello
message by setting the TLSconn.serverName. This is required for some
https sites.
This commit is contained in:
cinap_lenrek 2015-05-21 02:26:57 +02:00
parent a1bbf39c34
commit 40360a992d
3 changed files with 93 additions and 15 deletions

View file

@ -383,6 +383,7 @@ typedef struct TLSconn{
uchar *sessionKey; uchar *sessionKey;
int sessionKeylen; int sessionKeylen;
char *sessionConst; char *sessionConst;
char *serverName;
} TLSconn; } TLSconn;
/* tlshand.c */ /* tlshand.c */

View file

@ -107,6 +107,7 @@ typedef struct TLSconn {
uchar *sessionKey; /* opt IN/OUT session key */ uchar *sessionKey; /* opt IN/OUT session key */
int sessionKeylen; /* opt IN session key length */ int sessionKeylen; /* opt IN session key length */
char *sessionConst; /* opt IN session constant */ char *sessionConst; /* opt IN session constant */
char *serverName; /* opt IN server name */
} TLSconn; } TLSconn;
.EE .EE
.PP .PP

View file

@ -96,13 +96,15 @@ typedef struct Msg{
Bytes* sid; Bytes* sid;
Ints* ciphers; Ints* ciphers;
Bytes* compressors; Bytes* compressors;
Bytes* extensions;
} clientHello; } clientHello;
struct { struct {
int version; int version;
uchar random[RandomSize]; uchar random[RandomSize];
Bytes* sid; Bytes* sid;
int cipher; int cipher;
int compressor; int compressor;
Bytes* extensions;
} serverHello; } serverHello;
struct { struct {
int ncert; int ncert;
@ -266,8 +268,8 @@ static uchar compressors[] = {
CompressionNull, CompressionNull,
}; };
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain); static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)); static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
static void msgClear(Msg *m); static void msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m); static char* msgPrint(char *buf, int n, Msg *m);
static int msgRecv(TlsConnection *c, Msg *m); static int msgRecv(TlsConnection *c, Msg *m);
@ -390,6 +392,33 @@ tlsServer(int fd, TLSconn *conn)
return data; return data;
} }
static uchar*
tlsClientExtensions(TLSconn *conn, int *plen)
{
uchar *b, *p;
int n, m;
p = b = nil;
// RFC6066 - Server Name Identification
if(conn->serverName != nil){
n = strlen(conn->serverName);
m = p - b;
b = erealloc(b, m+2+2+2+1+2+n);
p = b + m;
put16(p, 0), p += 2; /* Type: server_name */
put16(p, 2+1+2+n), p += 2; /* Length */
put16(p, 1+2+n), p += 2; /* Server Name list length */
*p++ = 0; /* Server Name Type: host_name */
put16(p, n), p += 2; /* Server Name length */
memmove(p, conn->serverName, n);
p += n;
}
*plen = p - b;
return b;
}
// push TLS onto fd, returning new (application) file descriptor // push TLS onto fd, returning new (application) file descriptor
// or -1 if error. // or -1 if error.
int int
@ -399,6 +428,7 @@ tlsClient(int fd, TLSconn *conn)
char dname[64]; char dname[64];
int n, data, ctl, hand; int n, data, ctl, hand;
TlsConnection *tls; TlsConnection *tls;
uchar *ext;
if(conn == nil) if(conn == nil)
return -1; return -1;
@ -426,7 +456,10 @@ tlsClient(int fd, TLSconn *conn)
return -1; return -1;
} }
fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->trace); ext = tlsClientExtensions(conn, &n);
tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen,
ext, n, conn->trace);
free(ext);
close(hand); close(hand);
close(ctl); close(ctl);
if(tls == nil){ if(tls == nil){
@ -466,7 +499,7 @@ countchain(PEMChain *p)
} }
static TlsConnection * static TlsConnection *
tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chp) tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
{ {
TlsConnection *c; TlsConnection *c;
Msg m; Msg m;
@ -531,12 +564,12 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
tlsError(c, EHandshakeFailure, "can't initialize security: %r"); tlsError(c, EHandshakeFailure, "can't initialize security: %r");
goto Err; goto Err;
} }
c->sec->rpc = factotum_rsa_open(cert, ncert); c->sec->rpc = factotum_rsa_open(cert, certlen);
if(c->sec->rpc == nil){ if(c->sec->rpc == nil){
tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
goto Err; goto Err;
} }
c->sec->rsapub = X509toRSApub(cert, ncert, nil, 0); c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
if(c->sec->rsapub == nil){ if(c->sec->rsapub == nil){
tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
goto Err; goto Err;
@ -558,7 +591,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
numcerts = countchain(chp); numcerts = countchain(chp);
m.u.certificate.ncert = 1 + numcerts; m.u.certificate.ncert = 1 + numcerts;
m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
m.u.certificate.certs[0] = makebytes(cert, ncert); m.u.certificate.certs[0] = makebytes(cert, certlen);
for (i = 0; i < numcerts && chp; i++, chp = chp->next) for (i = 0; i < numcerts && chp; i++, chp = chp->next)
m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
if(!msgSend(c, &m, AQueue)) if(!msgSend(c, &m, AQueue))
@ -702,7 +735,8 @@ Out:
} }
static TlsConnection * static TlsConnection *
tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)) tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
int (*trace)(char*fmt, ...))
{ {
TlsConnection *c; TlsConnection *c;
Msg m; Msg m;
@ -735,6 +769,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
m.u.clientHello.sid = makebytes(csid, ncsid); m.u.clientHello.sid = makebytes(csid, ncsid);
m.u.clientHello.ciphers = makeciphers(); m.u.clientHello.ciphers = makeciphers();
m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
m.u.clientHello.extensions = makebytes(ext, extlen);
if(!msgSend(c, &m, AFlush)) if(!msgSend(c, &m, AFlush))
goto Err; goto Err;
msgClear(&m); msgClear(&m);
@ -1015,6 +1050,15 @@ msgSend(TlsConnection *c, Msg *m, int act)
p[0] = n; p[0] = n;
memmove(p+1, m->u.clientHello.compressors->data, n); memmove(p+1, m->u.clientHello.compressors->data, n);
p += n+1; p += n+1;
if(m->u.clientHello.extensions == nil)
break;
n = m->u.clientHello.extensions->len;
if(n == 0)
break;
put16(p, n);
memmove(p+2, m->u.clientHello.extensions->data, n);
p += n+2;
break; break;
case HServerHello: case HServerHello:
put16(p, m->u.serverHello.version); put16(p, m->u.serverHello.version);
@ -1035,6 +1079,15 @@ msgSend(TlsConnection *c, Msg *m, int act)
p += 2; p += 2;
p[0] = m->u.serverHello.compressor; p[0] = m->u.serverHello.compressor;
p += 1; p += 1;
if(m->u.serverHello.extensions == nil)
break;
n = m->u.serverHello.extensions->len;
if(n == 0)
break;
put16(p, n);
memmove(p+2, m->u.serverHello.extensions->data, n);
p += n+2;
break; break;
case HServerHelloDone: case HServerHelloDone:
break; break;
@ -1249,9 +1302,17 @@ msgRecv(TlsConnection *c, Msg *m)
if(n < 1 || n < p[0]+1 || p[0] == 0) if(n < 1 || n < p[0]+1 || p[0] == 0)
goto Short; goto Short;
nn = p[0]; nn = p[0];
m->u.clientHello.compressors = newbytes(nn); m->u.clientHello.compressors = makebytes(p+1, nn);
memmove(m->u.clientHello.compressors->data, p+1, nn); p += nn + 1;
n -= nn + 1; n -= nn + 1;
if(n < 2)
break;
nn = get16(p);
if(nn > n-2)
goto Short;
m->u.clientHello.extensions = makebytes(p+2, nn);
n -= nn + 2;
break; break;
case HServerHello: case HServerHello:
if(n < 2) if(n < 2)
@ -1276,7 +1337,16 @@ msgRecv(TlsConnection *c, Msg *m)
goto Short; goto Short;
m->u.serverHello.cipher = get16(p); m->u.serverHello.cipher = get16(p);
m->u.serverHello.compressor = p[2]; m->u.serverHello.compressor = p[2];
p += 3;
n -= 3; n -= 3;
if(n < 2)
break;
nn = get16(p);
if(nn > n-2)
goto Short;
m->u.serverHello.extensions = makebytes(p+2, nn);
n -= nn + 2;
break; break;
case HCertificate: case HCertificate:
if(n < 3) if(n < 3)
@ -1409,7 +1479,7 @@ msgRecv(TlsConnection *c, Msg *m)
break; break;
} }
if(type != HClientHello && n != 0) if(type != HClientHello && type != HServerHello && n != 0)
goto Short; goto Short;
Ok: Ok:
if(c->trace){ if(c->trace){
@ -1440,9 +1510,11 @@ msgClear(Msg *m)
freebytes(m->u.clientHello.sid); freebytes(m->u.clientHello.sid);
freeints(m->u.clientHello.ciphers); freeints(m->u.clientHello.ciphers);
freebytes(m->u.clientHello.compressors); freebytes(m->u.clientHello.compressors);
freebytes(m->u.clientHello.extensions);
break; break;
case HServerHello: case HServerHello:
freebytes(m->u.clientHello.sid); freebytes(m->u.serverHello.sid);
freebytes(m->u.serverHello.extensions);
break; break;
case HCertificate: case HCertificate:
for(i=0; i<m->u.certificate.ncert; i++) for(i=0; i<m->u.certificate.ncert; i++)
@ -1534,6 +1606,8 @@ msgPrint(char *buf, int n, Msg *m)
bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n"); bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n"); bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n"); bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
if(m->u.clientHello.extensions != nil)
bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
break; break;
case HServerHello: case HServerHello:
bs = seprint(bs, be, "ServerHello\n"); bs = seprint(bs, be, "ServerHello\n");
@ -1545,6 +1619,8 @@ msgPrint(char *buf, int n, Msg *m)
bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n"); bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher); bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor); bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
if(m->u.serverHello.extensions != nil)
bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
break; break;
case HCertificate: case HCertificate:
bs = seprint(bs, be, "Certificate\n"); bs = seprint(bs, be, "Certificate\n");