libsec: implement TLS-PSK for tlsClient()/tlsServer()

This commit is contained in:
cinap_lenrek 2015-12-25 17:05:05 +01:00
parent 4a6ab355c1
commit 39f18c9d88
3 changed files with 301 additions and 98 deletions

View file

@ -412,8 +412,10 @@ typedef struct TLSconn{
char dir[40]; /* connection directory */ char dir[40]; /* connection directory */
uchar *cert; /* certificate (local on input, remote on output) */ uchar *cert; /* certificate (local on input, remote on output) */
uchar *sessionID; uchar *sessionID;
uchar *psk;
int certlen; int certlen;
int sessionIDlen; int sessionIDlen;
int psklen;
int (*trace)(char*fmt, ...); int (*trace)(char*fmt, ...);
PEMChain*chain; /* optional extra certificate evidence for servers to present */ PEMChain*chain; /* optional extra certificate evidence for servers to present */
char *sessionType; char *sessionType;
@ -421,6 +423,7 @@ typedef struct TLSconn{
int sessionKeylen; int sessionKeylen;
char *sessionConst; char *sessionConst;
char *serverName; char *serverName;
char *pskID;
} TLSconn; } TLSconn;
/* tlshand.c */ /* tlshand.c */

View file

@ -100,7 +100,8 @@ typedef struct TLSconn {
char dir[40]; /* OUT connection directory */ char dir[40]; /* OUT connection directory */
uchar *cert; /* IN/OUT certificate */ uchar *cert; /* IN/OUT certificate */
uchar *sessionID; /* IN/OUT session ID */ uchar *sessionID; /* IN/OUT session ID */
int certlen, sessionIDlen; uchar *psk; /* opt IN pre-shared key */
int certlen, sessionIDlen, psklen;
int (*trace)(char*fmt, ...); int (*trace)(char*fmt, ...);
PEMChain *chain; PEMChain *chain;
char *sessionType; /* opt IN session type */ char *sessionType; /* opt IN session type */
@ -108,6 +109,7 @@ typedef struct TLSconn {
int sessionKeylen; /* opt IN session key length */ int sessionKeylen; /* opt IN session key length */
char *sessionConst; /* opt IN session constant */ char *sessionConst; /* opt IN session constant */
char *serverName; /* opt IN server name */ char *serverName; /* opt IN server name */
char *pskID; /* opt IN pre-shared key ID */
} TLSconn; } TLSconn;
.EE .EE
.PP .PP

View file

@ -135,9 +135,11 @@ typedef struct Msg{
Bytes **cas; Bytes **cas;
} certificateRequest; } certificateRequest;
struct { struct {
Bytes *pskid;
Bytes *key; Bytes *key;
} clientKeyExchange; } clientKeyExchange;
struct { struct {
Bytes *pskid;
Bytes *dh_p; Bytes *dh_p;
Bytes *dh_g; Bytes *dh_g;
Bytes *dh_Ys; Bytes *dh_Ys;
@ -159,6 +161,8 @@ typedef struct TlsSec{
int ok; // <0 killed; == 0 in progress; >0 reusable int ok; // <0 killed; == 0 in progress; >0 reusable
RSApub *rsapub; RSApub *rsapub;
AuthRpc *rpc; // factotum for rsa private key AuthRpc *rpc; // factotum for rsa private key
uchar *psk; // pre-shared key
int psklen;
uchar sec[MasterSecretSize]; // master secret uchar sec[MasterSecretSize]; // master secret
uchar crandom[RandomSize]; // client random uchar crandom[RandomSize]; // client random
uchar srandom[RandomSize]; // server random uchar srandom[RandomSize]; // server random
@ -223,6 +227,7 @@ enum {
EInternalError = 80, EInternalError = 80,
EUserCanceled = 90, EUserCanceled = 90,
ENoRenegotiation = 100, ENoRenegotiation = 100,
EUnknownPSKidentity = 115,
EMax = 256 EMax = 256
}; };
@ -274,6 +279,16 @@ enum {
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0XC013, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0XC013,
TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0XC014, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0XC014,
TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027,
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCA8,
TLS_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCAA,
GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC13,
GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC15,
TLS_PSK_WITH_CHACHA20_POLY1305 = 0xCCAB,
TLS_PSK_WITH_AES_128_CBC_SHA256 = 0x00AE,
TLS_PSK_WITH_AES_128_CBC_SHA = 0x008C,
}; };
// compression methods // compression methods
@ -283,10 +298,12 @@ enum {
}; };
static Algs cipherAlgs[] = { static Algs cipherAlgs[] = {
{"ccpoly96_aead", "clear", 2*(32+12), 0xCCA8}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF) {"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305},
{"ccpoly96_aead", "clear", 2*(32+12), 0xCCAA}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF) {"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305},
{"ccpoly64_aead", "clear", 2*32, 0xCC13}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft)
{"ccpoly64_aead", "clear", 2*32, 0xCC15}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft) {"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305},
{"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305},
{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256}, {"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256},
{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
@ -299,6 +316,11 @@ static Algs cipherAlgs[] = {
{"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}, {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
{"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA}, {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
{"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA}, {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
// PSK cipher suits
{"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305},
{"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256},
{"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA},
}; };
static uchar compressors[] = { static uchar compressors[] = {
@ -327,8 +349,15 @@ static int sigalgs[] = {
0x0201, /* SHA1 RSA */ 0x0201, /* SHA1 RSA */
}; };
static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain); static TlsConnection *tlsServer2(int ctl, int hand,
static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...)); uchar *cert, int certlen,
char *pskid, uchar *psk, int psklen,
int (*trace)(char*fmt, ...), PEMChain *chain);
static TlsConnection *tlsClient2(int ctl, int hand,
uchar *csid, int ncsid,
uchar *cert, int certlen,
char *pskid, uchar *psk, int psklen,
uchar *ext, int extlen, int (*trace)(char*fmt, ...));
static void msgClear(Msg *m); static void msgClear(Msg *m);
static char* msgPrint(char *buf, int n, Msg *m); static char* msgPrint(char *buf, int n, Msg *m);
static int msgRecv(TlsConnection *c, Msg *m); static int msgRecv(TlsConnection *c, Msg *m);
@ -340,15 +369,17 @@ static int finishedMatch(TlsConnection *c, Finished *f);
static void tlsConnectionFree(TlsConnection *c); static void tlsConnectionFree(TlsConnection *c);
static int setAlgs(TlsConnection *c, int a); static int setAlgs(TlsConnection *c, int a);
static int okCipher(Ints *cv); static int okCipher(Ints *cv, int ispsk);
static int okCompression(Bytes *cv); static int okCompression(Bytes *cv);
static int initCiphers(void); static int initCiphers(void);
static Ints* makeciphers(void); static Ints* makeciphers(int ispsk);
static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom); static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
static int tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm); static int tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
static int tlsSecPSKs(TlsSec *sec, int vers);
static TlsSec* tlsSecInitc(int cvers, uchar *crandom); static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers); static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
static int tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers);
static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys); static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys); static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient); static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
@ -424,7 +455,10 @@ tlsServer(int fd, TLSconn *conn)
return -1; return -1;
} }
fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain); tls = tlsServer2(ctl, hand,
conn->cert, conn->certlen,
conn->pskID, conn->psk, conn->psklen,
conn->trace, conn->chain);
snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
data = open(dname, ORDWR); data = open(dname, ORDWR);
close(hand); close(hand);
@ -435,7 +469,7 @@ tlsServer(int fd, TLSconn *conn)
return -1; return -1;
} }
free(conn->cert); free(conn->cert);
conn->cert = 0; // client certificates are not yet implemented conn->cert = nil; // client certificates are not yet implemented
conn->certlen = 0; conn->certlen = 0;
conn->sessionIDlen = tls->sid->len; conn->sessionIDlen = tls->sid->len;
conn->sessionID = emalloc(conn->sessionIDlen); conn->sessionID = emalloc(conn->sessionIDlen);
@ -561,7 +595,10 @@ tlsClient(int fd, TLSconn *conn)
} }
fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
ext = tlsClientExtensions(conn, &n); ext = tlsClientExtensions(conn, &n);
tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, tls = tlsClient2(ctl, hand,
conn->sessionID, conn->sessionIDlen,
conn->cert, conn->certlen,
conn->pskID, conn->psk, conn->psklen,
ext, n, conn->trace); ext, n, conn->trace);
free(ext); free(ext);
close(hand); close(hand);
@ -570,9 +607,14 @@ tlsClient(int fd, TLSconn *conn)
close(data); close(data);
return -1; return -1;
} }
conn->certlen = tls->cert->len; if(tls->cert != nil){
conn->cert = emalloc(conn->certlen); conn->certlen = tls->cert->len;
memcpy(conn->cert, tls->cert->data, conn->certlen); conn->cert = emalloc(conn->certlen);
memcpy(conn->cert, tls->cert->data, conn->certlen);
} else {
conn->certlen = 0;
conn->cert = nil;
}
conn->sessionIDlen = tls->sid->len; conn->sessionIDlen = tls->sid->len;
conn->sessionID = emalloc(conn->sessionIDlen); conn->sessionID = emalloc(conn->sessionIDlen);
memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen); memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
@ -603,7 +645,10 @@ countchain(PEMChain *p)
} }
static TlsConnection * static TlsConnection *
tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp) tlsServer2(int ctl, int hand,
uchar *cert, int certlen,
char *pskid, uchar *psk, int psklen,
int (*trace)(char*fmt, ...), PEMChain *chp)
{ {
TlsConnection *c; TlsConnection *c;
Msg m; Msg m;
@ -641,7 +686,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
} }
memmove(c->crandom, m.u.clientHello.random, RandomSize); memmove(c->crandom, m.u.clientHello.random, RandomSize);
cipher = okCipher(m.u.clientHello.ciphers); cipher = okCipher(m.u.clientHello.ciphers, psklen > 0);
if(cipher < 0) { if(cipher < 0) {
// reply with EInsufficientSecurity if we know that's the case // reply with EInsufficientSecurity if we know that's the case
if(cipher == -2) if(cipher == -2)
@ -662,21 +707,27 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
csid = m.u.clientHello.sid; csid = m.u.clientHello.sid;
if(trace) if(trace)
trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len); trace(" cipher %x, compressor %x, csidlen %d\n", cipher, compressor, csid->len);
c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom); c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
if(c->sec == nil){ if(c->sec == nil){
tlsError(c, EHandshakeFailure, "can't initialize security: %r"); tlsError(c, EHandshakeFailure, "can't initialize security: %r");
goto Err; goto Err;
} }
c->sec->rpc = factotum_rsa_open(cert, certlen); if(psklen > 0){
if(c->sec->rpc == nil){ c->sec->psk = psk;
tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); c->sec->psklen = psklen;
goto Err;
} }
c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); if(certlen > 0){
if(c->sec->rsapub == nil){ c->sec->rpc = factotum_rsa_open(cert, certlen);
tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); if(c->sec->rpc == nil){
goto Err; tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
goto Err;
}
c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
if(c->sec->rsapub == nil){
tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
goto Err;
}
} }
msgClear(&m); msgClear(&m);
@ -691,16 +742,18 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
goto Err; goto Err;
msgClear(&m); msgClear(&m);
m.tag = HCertificate; if(certlen > 0){
numcerts = countchain(chp); m.tag = HCertificate;
m.u.certificate.ncert = 1 + numcerts; numcerts = countchain(chp);
m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); m.u.certificate.ncert = 1 + numcerts;
m.u.certificate.certs[0] = makebytes(cert, certlen); m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
for (i = 0; i < numcerts && chp; i++, chp = chp->next) m.u.certificate.certs[0] = makebytes(cert, certlen);
m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); for (i = 0; i < numcerts && chp; i++, chp = chp->next)
if(!msgSend(c, &m, AQueue)) m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
goto Err; if(!msgSend(c, &m, AQueue))
msgClear(&m); goto Err;
msgClear(&m);
}
m.tag = HServerHelloDone; m.tag = HServerHelloDone;
if(!msgSend(c, &m, AFlush)) if(!msgSend(c, &m, AFlush))
@ -713,10 +766,29 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
tlsError(c, EUnexpectedMessage, "expected a client key exchange"); tlsError(c, EUnexpectedMessage, "expected a client key exchange");
goto Err; goto Err;
} }
if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){ if(pskid != nil){
tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); if(m.u.clientKeyExchange.pskid == nil
|| m.u.clientKeyExchange.pskid->len != strlen(pskid)
|| memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){
tlsError(c, EUnknownPSKidentity, "unknown or missing pskid");
goto Err;
}
}
if(certlen > 0){
if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
goto Err;
}
} else if(psklen > 0){
if(tlsSecPSKs(c->sec, c->version) < 0){
tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
goto Err;
}
} else {
tlsError(c, EInternalError, "no psk or certificate");
goto Err; goto Err;
} }
setSecrets(c->sec, kd, c->nsecret); setSecrets(c->sec, kd, c->nsecret);
if(trace) if(trace)
trace("tls secrets\n"); trace("tls secrets\n");
@ -786,7 +858,8 @@ isDHE(int tlsid)
case TLS_DHE_RSA_WITH_AES_128_CBC_SHA: case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
case TLS_DHE_RSA_WITH_AES_256_CBC_SHA: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA: case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
case 0xCCAA: case 0xCC15: // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 case TLS_DHE_RSA_WITH_CHACHA20_POLY1305:
case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305:
return 1; return 1;
} }
return 0; return 0;
@ -799,7 +872,20 @@ isECDHE(int tlsid)
case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256:
case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
case 0xCCA8: case 0xCC13: // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305:
case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305:
return 1;
}
return 0;
}
static int
isPSK(int tlsid)
{
switch(tlsid){
case TLS_PSK_WITH_CHACHA20_POLY1305:
case TLS_PSK_WITH_AES_128_CBC_SHA256:
case TLS_PSK_WITH_AES_128_CBC_SHA:
return 1; return 1;
} }
return 0; return 0;
@ -980,8 +1066,18 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg)
RSApub *pk; RSApub *pk;
char *err; char *err;
if(sig == nil || sig->len <= 0) if(par == nil || par->len <= 0)
return "no dh parameters";
if(sig == nil || sig->len <= 0){
if(c->sec->psklen > 0)
return nil;
return "no signature"; return "no signature";
}
if(c->cert == nil)
return "no certificate";
pk = X509toRSApub(c->cert->data, c->cert->len, nil, 0); pk = X509toRSApub(c->cert->data, c->cert->len, nil, 0);
if(pk == nil) if(pk == nil)
@ -1015,7 +1111,11 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg)
} }
static TlsConnection * static TlsConnection *
tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, tlsClient2(int ctl, int hand,
uchar *csid, int ncsid,
uchar *cert, int certlen,
char *pskid, uchar *psk, int psklen,
uchar *ext, int extlen,
int (*trace)(char*fmt, ...)) int (*trace)(char*fmt, ...))
{ {
TlsConnection *c; TlsConnection *c;
@ -1036,17 +1136,24 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
c->trace = trace; c->trace = trace;
c->isClient = 1; c->isClient = 1;
c->clientVersion = c->version; c->clientVersion = c->version;
c->cert = nil;
c->sec = tlsSecInitc(c->clientVersion, c->crandom); c->sec = tlsSecInitc(c->clientVersion, c->crandom);
if(c->sec == nil) if(c->sec == nil)
goto Err; goto Err;
if(psklen > 0){
c->sec->psk = psk;
c->sec->psklen = psklen;
}
/* client hello */ /* client hello */
memset(&m, 0, sizeof(m)); memset(&m, 0, sizeof(m));
m.tag = HClientHello; m.tag = HClientHello;
m.u.clientHello.version = c->clientVersion; m.u.clientHello.version = c->clientVersion;
memmove(m.u.clientHello.random, c->crandom, RandomSize); memmove(m.u.clientHello.random, c->crandom, RandomSize);
m.u.clientHello.sid = makebytes(csid, ncsid); m.u.clientHello.sid = makebytes(csid, ncsid);
m.u.clientHello.ciphers = makeciphers(); m.u.clientHello.ciphers = makeciphers(psklen > 0);
m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
m.u.clientHello.extensions = makebytes(ext, extlen); m.u.clientHello.extensions = makebytes(ext, extlen);
if(!msgSend(c, &m, AFlush)) if(!msgSend(c, &m, AFlush))
@ -1071,7 +1178,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
goto Err; goto Err;
} }
cipher = m.u.serverHello.cipher; cipher = m.u.serverHello.cipher;
if(!setAlgs(c, cipher)) { if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) {
tlsError(c, EIllegalParameter, "invalid cipher suite"); tlsError(c, EIllegalParameter, "invalid cipher suite");
goto Err; goto Err;
} }
@ -1081,48 +1188,47 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
} }
msgClear(&m); msgClear(&m);
/* certificate */
if(!msgRecv(c, &m) || m.tag != HCertificate) {
tlsError(c, EUnexpectedMessage, "expected a certificate");
goto Err;
}
if(m.u.certificate.ncert < 1) {
tlsError(c, EIllegalParameter, "runt certificate");
goto Err;
}
c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
msgClear(&m);
/* server key exchange */
dhx = isDHE(cipher) || isECDHE(cipher); dhx = isDHE(cipher) || isECDHE(cipher);
if(!msgRecv(c, &m)) if(!msgRecv(c, &m))
goto Err; goto Err;
if(m.tag == HCertificate){
if(m.u.certificate.ncert < 1) {
tlsError(c, EIllegalParameter, "runt certificate");
goto Err;
}
c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
msgClear(&m);
if(!msgRecv(c, &m))
goto Err;
} else if(psklen == 0) {
tlsError(c, EUnexpectedMessage, "expected a certificate");
goto Err;
}
if(m.tag == HServerKeyExchange) { if(m.tag == HServerKeyExchange) {
char *err; if(dhx){
char *err = verifyDHparams(c,
if(!dhx){ m.u.serverKeyExchange.dh_parameters,
m.u.serverKeyExchange.dh_signature,
m.u.serverKeyExchange.sigalg);
if(err != nil){
tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
goto Err;
}
if(isECDHE(cipher))
epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
m.u.serverKeyExchange.curve,
m.u.serverKeyExchange.dh_Ys);
else
epm = tlsSecDHEc(c->sec, c->srandom, c->version,
m.u.serverKeyExchange.dh_p,
m.u.serverKeyExchange.dh_g,
m.u.serverKeyExchange.dh_Ys);
if(epm == nil)
goto Badcert;
} else if(psklen == 0){
tlsError(c, EUnexpectedMessage, "got an server key exchange"); tlsError(c, EUnexpectedMessage, "got an server key exchange");
goto Err; goto Err;
} }
err = verifyDHparams(c,
m.u.serverKeyExchange.dh_parameters,
m.u.serverKeyExchange.dh_signature,
m.u.serverKeyExchange.sigalg);
if(err != nil){
tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err);
goto Err;
}
if(isECDHE(cipher))
epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
m.u.serverKeyExchange.curve,
m.u.serverKeyExchange.dh_Ys);
else
epm = tlsSecDHEc(c->sec, c->srandom, c->version,
m.u.serverKeyExchange.dh_p,
m.u.serverKeyExchange.dh_g,
m.u.serverKeyExchange.dh_Ys);
if(epm == nil)
goto Badcert;
msgClear(&m); msgClear(&m);
if(!msgRecv(c, &m)) if(!msgRecv(c, &m))
goto Err; goto Err;
@ -1146,14 +1252,22 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
} }
msgClear(&m); msgClear(&m);
if(!dhx) if(!dhx){
epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom, if(c->cert != nil){
c->cert->data, c->cert->len, c->version); epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
c->cert->data, c->cert->len, c->version);
if(epm == nil){ if(epm == nil){
Badcert: Badcert:
tlsError(c, EBadCertificate, "bad certificate: %r"); tlsError(c, EBadCertificate, "bad certificate: %r");
goto Err; goto Err;
}
} else if(psklen > 0) {
if(tlsSecPSKc(c->sec, c->srandom, c->version) < 0)
goto Badcert;
} else {
tlsError(c, EInternalError, "no psk or certificate");
goto Err;
}
} }
setSecrets(c->sec, kd, c->nsecret); setSecrets(c->sec, kd, c->nsecret);
@ -1182,12 +1296,13 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
/* client key exchange */ /* client key exchange */
m.tag = HClientKeyExchange; m.tag = HClientKeyExchange;
if(psklen > 0){
if(pskid == nil)
pskid = "";
m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid));
}
m.u.clientKeyExchange.key = epm; m.u.clientKeyExchange.key = epm;
epm = nil; epm = nil;
if(m.u.clientKeyExchange.key == nil) {
tlsError(c, EHandshakeFailure, "can't set secret: %r");
goto Err;
}
if(!msgSend(c, &m, AFlush)) if(!msgSend(c, &m, AFlush))
goto Err; goto Err;
@ -1423,8 +1538,17 @@ msgSend(TlsConnection *c, Msg *m, int act)
p += 2; p += 2;
memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len); memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
p += m->u.certificateVerify.signature->len; p += m->u.certificateVerify.signature->len;
break; break;
case HClientKeyExchange: case HClientKeyExchange:
if(m->u.clientKeyExchange.pskid != nil){
n = m->u.clientKeyExchange.pskid->len;
put16(p, n);
p += 2;
memmove(p, m->u.clientKeyExchange.pskid->data, n);
p += n;
}
if(m->u.clientKeyExchange.key == nil)
break;
n = m->u.clientKeyExchange.key->len; n = m->u.clientKeyExchange.key->len;
if(c->version != SSL3Version){ if(c->version != SSL3Version){
if(isECDHE(c->cipher)) if(isECDHE(c->cipher))
@ -1737,6 +1861,18 @@ msgRecv(TlsConnection *c, Msg *m)
case HServerHelloDone: case HServerHelloDone:
break; break;
case HServerKeyExchange: case HServerKeyExchange:
if(isPSK(c->cipher)){
if(n < 2)
goto Short;
nn = get16(p);
p += 2, n -= 2;
if(nn > n)
goto Short;
m->u.serverKeyExchange.pskid = makebytes(p, nn);
p += nn, n -= nn;
if(n == 0)
break;
}
if(n < 2) if(n < 2)
goto Short; goto Short;
s = p; s = p;
@ -1805,6 +1941,18 @@ msgRecv(TlsConnection *c, Msg *m)
* this message depends upon the encryption selected * this message depends upon the encryption selected
* assume rsa. * assume rsa.
*/ */
if(isPSK(c->cipher)){
if(n < 2)
goto Short;
nn = get16(p);
p += 2, n -= 2;
if(nn > n)
goto Short;
m->u.clientKeyExchange.pskid = makebytes(p, nn);
p += nn, n -= nn;
if(n == 0)
break;
}
if(c->version == SSL3Version) if(c->version == SSL3Version)
nn = n; nn = n;
else{ else{
@ -1883,6 +2031,7 @@ msgClear(Msg *m)
case HServerHelloDone: case HServerHelloDone:
break; break;
case HServerKeyExchange: case HServerKeyExchange:
freebytes(m->u.serverKeyExchange.pskid);
freebytes(m->u.serverKeyExchange.dh_p); freebytes(m->u.serverKeyExchange.dh_p);
freebytes(m->u.serverKeyExchange.dh_g); freebytes(m->u.serverKeyExchange.dh_g);
freebytes(m->u.serverKeyExchange.dh_Ys); freebytes(m->u.serverKeyExchange.dh_Ys);
@ -1890,6 +2039,7 @@ msgClear(Msg *m)
freebytes(m->u.serverKeyExchange.dh_signature); freebytes(m->u.serverKeyExchange.dh_signature);
break; break;
case HClientKeyExchange: case HClientKeyExchange:
freebytes(m->u.clientKeyExchange.pskid);
freebytes(m->u.clientKeyExchange.key); freebytes(m->u.clientKeyExchange.key);
break; break;
case HFinished: case HFinished:
@ -1998,6 +2148,10 @@ msgPrint(char *buf, int n, Msg *m)
break; break;
case HServerKeyExchange: case HServerKeyExchange:
bs = seprint(bs, be, "HServerKeyExchange\n"); bs = seprint(bs, be, "HServerKeyExchange\n");
if(m->u.serverKeyExchange.pskid != nil)
bs = bytesPrint(bs, be, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n");
if(m->u.serverKeyExchange.dh_parameters == nil)
break;
if(m->u.serverKeyExchange.curve != 0){ if(m->u.serverKeyExchange.curve != 0){
bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve); bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
} else { } else {
@ -2012,7 +2166,10 @@ msgPrint(char *buf, int n, Msg *m)
break; break;
case HClientKeyExchange: case HClientKeyExchange:
bs = seprint(bs, be, "HClientKeyExchange\n"); bs = seprint(bs, be, "HClientKeyExchange\n");
bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n"); if(m->u.clientKeyExchange.pskid != nil)
bs = bytesPrint(bs, be, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n");
if(m->u.clientKeyExchange.key != nil)
bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
break; break;
case HFinished: case HFinished:
bs = seprint(bs, be, "HFinished\n"); bs = seprint(bs, be, "HFinished\n");
@ -2137,7 +2294,7 @@ setAlgs(TlsConnection *c, int a)
} }
static int static int
okCipher(Ints *cv) okCipher(Ints *cv, int ispsk)
{ {
int weak, i, j, c; int weak, i, j, c;
@ -2148,6 +2305,8 @@ okCipher(Ints *cv)
weak = 0; weak = 0;
else else
weak &= weakCipher[c]; weak &= weakCipher[c];
if(isPSK(c) != ispsk)
continue;
if(isDHE(c) || isECDHE(c)) if(isDHE(c) || isECDHE(c))
continue; /* TODO: not implemented for server */ continue; /* TODO: not implemented for server */
for(j = 0; j < nelem(cipherAlgs); j++) for(j = 0; j < nelem(cipherAlgs); j++)
@ -2243,17 +2402,17 @@ initCiphers(void)
} }
static Ints* static Ints*
makeciphers(void) makeciphers(int ispsk)
{ {
Ints *is; Ints *is;
int i, j; int i, j;
is = newints(nciphers); is = newints(nciphers);
j = 0; j = 0;
for(i = 0; i < nelem(cipherAlgs); i++){ for(i = 0; i < nelem(cipherAlgs); i++)
if(cipherAlgs[i].ok) if(cipherAlgs[i].ok && isPSK(cipherAlgs[i].tlsid) == ispsk)
is->data[j++] = cipherAlgs[i].tlsid; is->data[j++] = cipherAlgs[i].tlsid;
} is->len = j;
return is; return is;
} }
@ -2489,6 +2648,17 @@ Err:
return -1; return -1;
} }
static int
tlsSecPSKs(TlsSec *sec, int vers)
{
if(setVers(sec, vers) < 0){
sec->ok = -1;
return -1;
}
setMasterSecret(sec, newbytes(sec->psklen));
return 0;
}
static TlsSec* static TlsSec*
tlsSecInitc(int cvers, uchar *crandom) tlsSecInitc(int cvers, uchar *crandom)
{ {
@ -2500,6 +2670,18 @@ tlsSecInitc(int cvers, uchar *crandom)
return sec; return sec;
} }
static int
tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers)
{
memmove(sec->srandom, srandom, RandomSize);
if(setVers(sec, vers) < 0){
sec->ok = -1;
return -1;
}
setMasterSecret(sec, newbytes(sec->psklen));
return 0;
}
static Bytes* static Bytes*
tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers) tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
{ {
@ -2608,6 +2790,22 @@ setSecrets(TlsSec *sec, uchar *kd, int nkd)
static void static void
setMasterSecret(TlsSec *sec, Bytes *pm) setMasterSecret(TlsSec *sec, Bytes *pm)
{ {
if(sec->psklen > 0){
Bytes *opm = pm;
uchar *p;
/* concatenate psk to pre-master secret */
pm = newbytes(4 + opm->len + sec->psklen);
p = pm->data;
put16(p, opm->len), p += 2;
memmove(p, opm->data, opm->len), p += opm->len;
put16(p, sec->psklen), p += 2;
memmove(p, sec->psk, sec->psklen);
memset(opm->data, 0, opm->len);
freebytes(opm);
}
(*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret", (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
sec->crandom, RandomSize, sec->srandom, RandomSize); sec->crandom, RandomSize, sec->srandom, RandomSize);