diff --git a/sys/src/libsec/port/tlshand.c b/sys/src/libsec/port/tlshand.c index 8603fd155..51f9b72f3 100644 --- a/sys/src/libsec/port/tlshand.c +++ b/sys/src/libsec/port/tlshand.c @@ -20,14 +20,11 @@ enum { MaxChunk = 1<<15, MAXdlen = SHA2_512dlen, RandomSize = 32, - SidSize = 32, MasterSecretSize = 48, AQueue = 0, AFlush = 1, }; -typedef struct TlsSec TlsSec; - typedef struct Bytes{ int len; uchar data[1]; // [len] @@ -62,20 +59,38 @@ typedef struct HandshakeHash { SHA2_256state sha2_256; } HandshakeHash; +typedef struct TlsSec TlsSec; +struct TlsSec { + RSApub *rsapub; + AuthRpc *rpc; // factotum for rsa private key + uchar *psk; // pre-shared key + int psklen; + int clientVers; // version in ClientHello + uchar sec[MasterSecretSize]; // master secret + uchar crandom[RandomSize]; // client random + uchar srandom[RandomSize]; // server random + // byte generation and handshake checksum + void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int); + void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int); + int nfin; +}; + typedef struct TlsConnection{ - TlsSec *sec; // security management goo + TlsSec sec[1]; // security management goo int hand, ctl; // record layer file descriptors int erred; // set when tlsError called int (*trace)(char*fmt, ...); // for debugging int version; // protocol we are speaking - int verset; // version has been set - int ver2hi; // server got a version 2 hello - int isClient; // is this the client or server? - Bytes *sid; // SessionID Bytes *cert; // server certificate; only last - no chain - Lock statelk; - int state; // must be set using setstate + int cipher; + int nsecret; // amount of secret data to init keys + char *digest; // name of digest algorithm to use + char *enc; // name of encryption algorithm to use + + // for finished messages + HandshakeHash handhash; + Finished finished; // input buffer for handshake messages uchar recvbuf[MaxChunk]; @@ -84,18 +99,6 @@ typedef struct TlsConnection{ // output buffer uchar sendbuf[MaxChunk]; uchar *sendp; - - uchar crandom[RandomSize]; // client random - uchar srandom[RandomSize]; // server random - int clientVersion; // version in ClientHello - int cipher; - char *digest; // name of digest algorithm to use - char *enc; // name of encryption algorithm to use - int nsecret; // amount of secret data to init keys - - // for finished messages - HandshakeHash handhash; - Finished finished; } TlsConnection; typedef struct Msg{ @@ -149,24 +152,6 @@ typedef struct Msg{ } u; } Msg; -typedef struct TlsSec{ - char *server; // name of remote; nil for server - int ok; // <0 killed; == 0 in progress; >0 reusable - RSApub *rsapub; - AuthRpc *rpc; // factotum for rsa private key - uchar *psk; // pre-shared key - int psklen; - uchar sec[MasterSecretSize]; // master secret - uchar crandom[RandomSize]; // client random - uchar srandom[RandomSize]; // server random - int clientVers; // version in ClientHello - int vers; // final version - // byte generation and handshake checksum - void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int); - void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int); - int nfin; -} TlsSec; - enum { SSL3Version = 0x0300, @@ -386,7 +371,6 @@ static TlsConnection *tlsServer2(int ctl, int hand, char *pskid, uchar *psk, int psklen, int (*trace)(char*fmt, ...), PEMChain *chain); static TlsConnection *tlsClient2(int ctl, int hand, - uchar *csid, int ncsid, uchar *cert, int certlen, char *pskid, uchar *psk, int psklen, uchar *ext, int extlen, int (*trace)(char*fmt, ...)); @@ -397,6 +381,7 @@ static int msgSend(TlsConnection *c, Msg *m, int act); static void tlsError(TlsConnection *c, int err, char *msg, ...); #pragma varargck argpos tlsError 3 static int setVersion(TlsConnection *c, int version); +static int setSecrets(TlsConnection *c, int isclient); static int finishedMatch(TlsConnection *c, Finished *f); static void tlsConnectionFree(TlsConnection *c); @@ -406,27 +391,19 @@ static int okCompression(Bytes *cv); static int initCiphers(void); static Ints* makeciphers(int ispsk); -static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom); -static int tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm); -static int tlsSecPSKs(TlsSec *sec, int vers); -static TlsSec* tlsSecInitc(int cvers, uchar *crandom); -static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers); -static int tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers); -static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys); -static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys); +static void tlsSecInits(TlsSec *sec, int cvers, uchar *crandom); +static int tlsSecRSAs(TlsSec *sec, Bytes *epm); +static void tlsSecPSKs(TlsSec *sec); +static void tlsSecInitc(TlsSec *sec, int cvers); +static Bytes* tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert); +static void tlsSecPSKc(TlsSec *sec); +static Bytes* tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys); +static Bytes* tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys); +static void tlsSecVers(TlsSec *sec, int v); static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient); -static void tlsSecOk(TlsSec *sec); -static void tlsSecClose(TlsSec *sec); static void setMasterSecret(TlsSec *sec, Bytes *pm); -static void setSecrets(TlsSec *sec, uchar *kd, int nkd); static Bytes* pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype); static Bytes* pkcs1_decrypt(TlsSec *sec, Bytes *cipher); -static void tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); -static void tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); -static void sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); -static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, - uchar *seed0, int nseed0, uchar *seed1, int nseed1); -static int setVers(TlsSec *sec, int version); static AuthRpc* factotum_rsa_open(RSApub *rsapub); static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher); @@ -482,26 +459,27 @@ tlsServer(int fd, TLSconn *conn) close(ctl); return -1; } + data = -1; fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->pskID, conn->psk, conn->psklen, conn->trace, conn->chain); - snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); - data = open(dname, ORDWR); + if(tls != nil){ + snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); + data = open(dname, ORDWR); + } close(hand); close(ctl); - if(data < 0 || tls == nil){ - if(tls != nil) - tlsConnectionFree(tls); + if(data < 0){ + tlsConnectionFree(tls); return -1; } free(conn->cert); conn->cert = nil; // client certificates are not yet implemented conn->certlen = 0; - conn->sessionIDlen = tls->sid->len; - conn->sessionID = emalloc(conn->sessionIDlen); - memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen); + conn->sessionIDlen = 0; + conn->sessionID = nil; if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0) @@ -624,7 +602,6 @@ tlsClient(int fd, TLSconn *conn) fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); ext = tlsClientExtensions(conn, &n); tls = tlsClient2(ctl, hand, - conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->pskID, conn->psk, conn->psklen, ext, n, conn->trace); @@ -635,6 +612,7 @@ tlsClient(int fd, TLSconn *conn) close(data); return -1; } + free(conn->cert); if(tls->cert != nil){ conn->certlen = tls->cert->len; conn->cert = emalloc(conn->certlen); @@ -643,9 +621,8 @@ tlsClient(int fd, TLSconn *conn) conn->certlen = 0; conn->cert = nil; } - conn->sessionIDlen = tls->sid->len; - conn->sessionID = emalloc(conn->sessionIDlen); - memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen); + conn->sessionIDlen = 0; + conn->sessionID = nil; if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0) @@ -680,15 +657,13 @@ tlsServer2(int ctl, int hand, { TlsConnection *c; Msg m; - Bytes *csid; - uchar sid[SidSize], kd[MaxKeyData]; - char *secrets; - int cipher, compressor, nsid, rv, numcerts, i; + int cipher, compressor, numcerts, i; if(trace) trace("tlsServer2\n"); if(!initCiphers()) return nil; + c = emalloc(sizeof(TlsConnection)); c->ctl = ctl; c->hand = hand; @@ -705,15 +680,13 @@ tlsServer2(int ctl, int hand, tlsError(c, EUnexpectedMessage, "expected a client hello"); goto Err; } - c->clientVersion = m.u.clientHello.version; if(trace) - trace("ClientHello version %x\n", c->clientVersion); - if(setVersion(c, c->clientVersion) < 0) { + trace("ClientHello version %x\n", m.u.clientHello.version); + if(setVersion(c, m.u.clientHello.version) < 0) { tlsError(c, EIllegalParameter, "incompatible version"); goto Err; } - memmove(c->crandom, m.u.clientHello.random, RandomSize); cipher = okCipher(m.u.clientHello.ciphers, psklen > 0); if(cipher < 0 || !setAlgs(c, cipher)) { tlsError(c, EHandshakeFailure, "no matching cipher suite"); @@ -724,11 +697,11 @@ tlsServer2(int ctl, int hand, tlsError(c, EHandshakeFailure, "no matching compressor"); goto Err; } - - csid = m.u.clientHello.sid; if(trace) - trace(" cipher %x, compressor %x, csidlen %d\n", cipher, compressor, csid->len); - c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom); + trace(" cipher %x, compressor %x\n", cipher, compressor); + + tlsSecInits(c->sec, m.u.clientHello.version, m.u.clientHello.random); + tlsSecVers(c->sec, c->version); if(psklen > 0){ c->sec->psk = psk; c->sec->psklen = psklen; @@ -750,14 +723,12 @@ tlsServer2(int ctl, int hand, m.tag = HServerHello; m.u.serverHello.version = c->version; - memmove(m.u.serverHello.random, c->srandom, RandomSize); + memmove(m.u.serverHello.random, c->sec->srandom, RandomSize); m.u.serverHello.cipher = cipher; m.u.serverHello.compressor = compressor; - c->sid = makebytes(sid, nsid); - m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len); + m.u.serverHello.sid = makebytes(nil, 0); if(!msgSend(c, &m, AQueue)) goto Err; - msgClear(&m); if(certlen > 0){ m.tag = HCertificate; @@ -769,13 +740,11 @@ tlsServer2(int ctl, int hand, m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); if(!msgSend(c, &m, AQueue)) goto Err; - msgClear(&m); } m.tag = HServerHelloDone; if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); if(!msgRecv(c, &m)) goto Err; @@ -792,34 +761,23 @@ tlsServer2(int ctl, int hand, } } if(certlen > 0){ - if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){ - tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); + if(tlsSecRSAs(c->sec, m.u.clientKeyExchange.key) < 0){ + tlsError(c, EHandshakeFailure, "couldn't set keys: %r"); goto Err; } } else if(psklen > 0){ - if(tlsSecPSKs(c->sec, c->version) < 0){ - tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); - goto Err; - } + tlsSecPSKs(c->sec); } else { tlsError(c, EInternalError, "no psk or certificate"); goto Err; } - setSecrets(c->sec, kd, c->nsecret); if(trace) trace("tls secrets\n"); - secrets = (char*)emalloc(2*c->nsecret); - enc64(secrets, 2*c->nsecret, kd, c->nsecret); - rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets); - memset(secrets, 0, 2*c->nsecret); - free(secrets); - memset(kd, 0, c->nsecret); - if(rv < 0){ - tlsError(c, EHandshakeFailure, "can't set keys: %r"); + if(setSecrets(c, 0) < 0){ + tlsError(c, EHandshakeFailure, "can't set secrets: %r"); goto Err; } - msgClear(&m); /* no CertificateVerify; skip to Finished */ if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){ @@ -852,19 +810,17 @@ tlsServer2(int ctl, int hand, m.u.finished = c->finished; if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); if(trace) trace("tls finished\n"); if(fprint(c->ctl, "opened") < 0) goto Err; - tlsSecOk(c->sec); return c; Err: msgClear(&m); tlsConnectionFree(c); - return 0; + return nil; } static int @@ -918,8 +874,7 @@ isPSK(int tlsid) } static Bytes* -tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, - Bytes *p, Bytes *g, Bytes *Ys) +tlsSecDHEc(TlsSec *sec, Bytes *p, Bytes *g, Bytes *Ys) { mpint *G, *P, *Y, *K; Bytes *epm; @@ -928,10 +883,6 @@ tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, if(p == nil || g == nil || Ys == nil) return nil; - memmove(sec->srandom, srandom, RandomSize); - if(setVers(sec, vers) < 0) - return nil; - epm = nil; P = bytestomp(p); G = bytestomp(g); @@ -959,7 +910,7 @@ Out: } static Bytes* -tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys) +tlsSecECDHEc(TlsSec *sec, int curve, Bytes *Ys) { Namedcurve *nc, *enc; Bytes *epm; @@ -978,10 +929,6 @@ tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys) if(nc == enc) return nil; - - memmove(sec->srandom, srandom, RandomSize); - if(setVers(sec, vers) < 0) - return nil; ecdominit(&dom, nc->init); pub = ecdecodepub(&dom, Ys->data, Ys->len); @@ -1031,7 +978,7 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg) char *err; if(par == nil || par->len <= 0) - return "no dh parameters"; + return "no DH parameters"; if(sig == nil || sig->len <= 0){ if(c->sec->psklen > 0) @@ -1043,8 +990,8 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg) return "no certificate"; blob = newbytes(2*RandomSize + par->len); - memmove(blob->data+0*RandomSize, c->crandom, RandomSize); - memmove(blob->data+1*RandomSize, c->srandom, RandomSize); + memmove(blob->data+0*RandomSize, c->sec->crandom, RandomSize); + memmove(blob->data+1*RandomSize, c->sec->srandom, RandomSize); memmove(blob->data+2*RandomSize, par->data, par->len); if(c->version < TLS12Version){ digestlen = MD5dlen + SHA1dlen; @@ -1089,7 +1036,6 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg) static TlsConnection * tlsClient2(int ctl, int hand, - uchar *csid, int ncsid, uchar *cert, int certlen, char *pskid, uchar *psk, int psklen, uchar *ext, int extlen, @@ -1097,25 +1043,23 @@ tlsClient2(int ctl, int hand, { TlsConnection *c; Msg m; - uchar kd[MaxKeyData]; - char *secrets; - int creq, dhx, rv, cipher; + int creq, dhx, cipher; Bytes *epm; if(!initCiphers()) return nil; + epm = nil; + memset(&m, 0, sizeof(m)); c = emalloc(sizeof(TlsConnection)); - c->version = ProtocolVersion; c->ctl = ctl; c->hand = hand; c->trace = trace; - c->isClient = 1; - c->clientVersion = c->version; c->cert = nil; - c->sec = tlsSecInitc(c->clientVersion, c->crandom); + c->version = ProtocolVersion; + tlsSecInitc(c->sec, c->version); if(psklen > 0){ c->sec->psk = psk; c->sec->psklen = psklen; @@ -1124,28 +1068,26 @@ tlsClient2(int ctl, int hand, /* client certificate */ c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); if(c->sec->rsapub == nil){ - tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); + tlsError(c, EInternalError, "invalid X509/rsa certificate"); goto Err; } c->sec->rpc = factotum_rsa_open(c->sec->rsapub); if(c->sec->rpc == nil){ - tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); + tlsError(c, EInternalError, "factotum_rsa_open: %r"); goto Err; } } /* client hello */ - memset(&m, 0, sizeof(m)); m.tag = HClientHello; - m.u.clientHello.version = c->clientVersion; - memmove(m.u.clientHello.random, c->crandom, RandomSize); - m.u.clientHello.sid = makebytes(csid, ncsid); + m.u.clientHello.version = c->version; + memmove(m.u.clientHello.random, c->sec->crandom, RandomSize); + m.u.clientHello.sid = makebytes(nil, 0); m.u.clientHello.ciphers = makeciphers(psklen > 0); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); m.u.clientHello.extensions = makebytes(ext, extlen); if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); /* server hello */ if(!msgRecv(c, &m)) @@ -1158,12 +1100,9 @@ tlsClient2(int ctl, int hand, tlsError(c, EIllegalParameter, "incompatible version: %r"); goto Err; } - memmove(c->srandom, m.u.serverHello.random, RandomSize); - c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len); - if(c->sid->len != 0 && c->sid->len != SidSize) { - tlsError(c, EIllegalParameter, "invalid server session identifier"); - goto Err; - } + tlsSecVers(c->sec, c->version); + memmove(c->sec->srandom, m.u.serverHello.random, RandomSize); + cipher = m.u.serverHello.cipher; if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) { tlsError(c, EIllegalParameter, "invalid cipher suite"); @@ -1173,7 +1112,6 @@ tlsClient2(int ctl, int hand, tlsError(c, EIllegalParameter, "invalid compression"); goto Err; } - msgClear(&m); dhx = isDHE(cipher) || isECDHE(cipher); if(!msgRecv(c, &m)) @@ -1184,7 +1122,6 @@ tlsClient2(int ctl, int hand, goto Err; } c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len); - msgClear(&m); if(!msgRecv(c, &m)) goto Err; } else if(psklen == 0) { @@ -1198,25 +1135,26 @@ tlsClient2(int ctl, int hand, m.u.serverKeyExchange.dh_signature, m.u.serverKeyExchange.sigalg); if(err != nil){ - tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err); + tlsError(c, EBadCertificate, "can't verify DH parameters: %s", err); goto Err; } if(isECDHE(cipher)) - epm = tlsSecECDHEc(c->sec, c->srandom, c->version, + epm = tlsSecECDHEc(c->sec, m.u.serverKeyExchange.curve, m.u.serverKeyExchange.dh_Ys); else - epm = tlsSecDHEc(c->sec, c->srandom, c->version, + epm = tlsSecDHEc(c->sec, m.u.serverKeyExchange.dh_p, m.u.serverKeyExchange.dh_g, m.u.serverKeyExchange.dh_Ys); - if(epm == nil) - goto Badcert; + if(epm == nil){ + tlsError(c, EHandshakeFailure, "bad DH parameters"); + goto Err; + } } else if(psklen == 0){ tlsError(c, EUnexpectedMessage, "got an server key exchange"); goto Err; } - msgClear(&m); if(!msgRecv(c, &m)) goto Err; } else if(dhx){ @@ -1228,7 +1166,6 @@ tlsClient2(int ctl, int hand, creq = 0; if(m.tag == HCertificateRequest) { creq = 1; - msgClear(&m); if(!msgRecv(c, &m)) goto Err; } @@ -1241,44 +1178,35 @@ tlsClient2(int ctl, int hand, if(!dhx){ if(c->cert != nil){ - epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom, - c->cert->data, c->cert->len, c->version); + epm = tlsSecRSAc(c->sec, c->cert->data, c->cert->len); if(epm == nil){ - Badcert: tlsError(c, EBadCertificate, "bad certificate: %r"); goto Err; } - } else if(psklen > 0) { - if(tlsSecPSKc(c->sec, c->srandom, c->version) < 0) - goto Badcert; + } else if(psklen > 0){ + tlsSecPSKc(c->sec); } else { tlsError(c, EInternalError, "no psk or certificate"); goto Err; } } - setSecrets(c->sec, kd, c->nsecret); - secrets = (char*)emalloc(2*c->nsecret); - enc64(secrets, 2*c->nsecret, kd, c->nsecret); - rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets); - memset(secrets, 0, 2*c->nsecret); - free(secrets); - memset(kd, 0, c->nsecret); - if(rv < 0){ - tlsError(c, EHandshakeFailure, "can't set keys: %r"); + if(trace) + trace("tls secrets\n"); + if(setSecrets(c, 1) < 0){ + tlsError(c, EHandshakeFailure, "can't set secrets: %r"); goto Err; } if(creq) { + m.tag = HCertificate; if(certlen > 0){ m.u.certificate.ncert = 1; m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); m.u.certificate.certs[0] = makebytes(cert, certlen); } - m.tag = HCertificate; if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); } /* client key exchange */ @@ -1293,7 +1221,6 @@ tlsClient2(int ctl, int hand, if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); /* certificate verify */ if(creq && certlen > 0) { @@ -1334,7 +1261,6 @@ tlsClient2(int ctl, int hand, m.tag = HCertificateVerify; if(!msgSend(c, &m, AFlush)) goto Err; - msgClear(&m); } /* change cipher spec */ @@ -1355,7 +1281,6 @@ tlsClient2(int ctl, int hand, tlsError(c, EInternalError, "can't flush after client Finished: %r"); goto Err; } - msgClear(&m); if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){ tlsError(c, EInternalError, "can't set finished 0: %r"); @@ -1381,7 +1306,6 @@ tlsClient2(int ctl, int hand, trace("unable to do final open: %r\n"); goto Err; } - tlsSecOk(c->sec); return c; Err: @@ -1433,13 +1357,11 @@ msgSend(TlsConnection *c, Msg *m, int act) // sid n = m->u.clientHello.sid->len; - assert(n < 256); p[0] = n; memmove(p+1, m->u.clientHello.sid->data, n); p += n+1; n = m->u.clientHello.ciphers->len; - assert(n > 0 && n < 200); put16(p, n*2); p += 2; for(i=0; iu.clientHello.compressors->len; - assert(n > 0); p[0] = n; memmove(p+1, m->u.clientHello.compressors->data, n); p += n+1; @@ -1472,7 +1393,6 @@ msgSend(TlsConnection *c, Msg *m, int act) // sid n = m->u.serverHello.sid->len; - assert(n < 256); p[0] = n; memmove(p+1, m->u.serverHello.sid->data, n); p += n+1; @@ -1548,7 +1468,7 @@ msgSend(TlsConnection *c, Msg *m, int act) // go back and fill in size n = p - c->sendp; - assert(p <= c->sendbuf + sizeof(c->sendbuf)); + assert(n <= sizeof(c->sendbuf)); put24(c->sendp+1, n-4); // remember hash of Handshake messages @@ -1599,8 +1519,9 @@ static int msgRecv(TlsConnection *c, Msg *m) { uchar *p, *s; - int type, n, nn, i, nsid, nrandom, nciph; + int type, n, nn, i; + msgClear(m); for(;;) { p = tlsReadN(c, 4); if(p == nil) @@ -1625,6 +1546,8 @@ msgRecv(TlsConnection *c, Msg *m) /* Cope with an SSL3 ClientHello expressed in SSL2 record format. This is sent by some clients that we must interoperate with, such as Java's JSSE and Microsoft's Internet Explorer. */ + int nsid, nrandom, nciph; + p = tlsReadN(c, n); if(p == nil) return 0; @@ -1683,14 +1606,12 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 2) goto Short; m->u.clientHello.version = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; if(n < RandomSize) goto Short; memmove(m->u.clientHello.random, p, RandomSize); - p += RandomSize; - n -= RandomSize; + p += RandomSize, n -= RandomSize; if(n < 1 || n < p[0]+1) goto Short; m->u.clientHello.sid = makebytes(p+1, p[0]); @@ -1700,23 +1621,20 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 2) goto Short; nn = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; if((nn & 1) || n < nn || nn < 2) goto Short; m->u.clientHello.ciphers = newints(nn >> 1); for(i = 0; i < nn; i += 2) m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]); - p += nn; - n -= nn; + p += nn, n -= nn; if(n < 1 || n < p[0]+1 || p[0] == 0) goto Short; nn = p[0]; m->u.clientHello.compressors = makebytes(p+1, nn); - p += nn + 1; - n -= nn + 1; + p += nn + 1, n -= nn + 1; if(n < 2) break; @@ -1730,14 +1648,12 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 2) goto Short; m->u.serverHello.version = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; if(n < RandomSize) goto Short; memmove(m->u.serverHello.random, p, RandomSize); - p += RandomSize; - n -= RandomSize; + p += RandomSize, n -= RandomSize; if(n < 1 || n < p[0]+1) goto Short; @@ -1749,8 +1665,7 @@ msgRecv(TlsConnection *c, Msg *m) goto Short; m->u.serverHello.cipher = get16(p); m->u.serverHello.compressor = p[2]; - p += 3; - n -= 3; + p += 3, n -= 3; if(n < 2) break; @@ -1764,8 +1679,7 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 3) goto Short; nn = get24(p); - p += 3; - n -= 3; + p += 3, n -= 3; if(nn == 0 && n > 0) goto Short; /* certs */ @@ -1774,15 +1688,13 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 3) goto Short; nn = get24(p); - p += 3; - n -= 3; + p += 3, n -= 3; if(nn > n) goto Short; m->u.certificate.ncert = i+1; m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*)); m->u.certificate.certs[i] = makebytes(p, nn); - p += nn; - n -= nn; + p += nn, n -= nn; i++; } break; @@ -1790,33 +1702,28 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 1) goto Short; nn = p[0]; - p += 1; - n -= 1; + p++, n--; if(nn > n) goto Short; m->u.certificateRequest.types = makebytes(p, nn); - p += nn; - n -= nn; + p += nn, n -= nn; if(c->version >= TLS12Version){ if(n < 2) goto Short; nn = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; if(nn & 1) goto Short; m->u.certificateRequest.sigalgs = newints(nn>>1); for(i = 0; i < nn; i += 2) m->u.certificateRequest.sigalgs->data[i >> 1] = get16(&p[i]); - p += nn; - n -= nn; + p += nn, n -= nn; } if(n < 2) goto Short; nn = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; /* nn == 0 can happen; yahoo's servers do it */ if(nn != n) goto Short; @@ -1826,16 +1733,14 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 2) goto Short; nn = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; if(nn < 1 || nn > n) goto Short; m->u.certificateRequest.nca = i+1; m->u.certificateRequest.cas = erealloc( m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*)); m->u.certificateRequest.cas[i] = makebytes(p, nn); - p += nn; - n -= nn; + p += nn, n -= nn; i++; } break; @@ -1940,8 +1845,7 @@ msgRecv(TlsConnection *c, Msg *m) if(n < 2) goto Short; nn = get16(p); - p += 2; - n -= 2; + p += 2, n -= 2; } if(n < nn) goto Short; @@ -1980,8 +1884,6 @@ msgClear(Msg *m) int i; switch(m->tag) { - default: - sysfatal("msgClear: unknown message type: %d", m->tag); case HHelloRequest: break; case HClientHello: @@ -2186,7 +2088,7 @@ tlsError(TlsConnection *c, int err, char *fmt, ...) static int setVersion(TlsConnection *c, int version) { - if(c->verset || version > MaxProtoVersion || version < MinProtoVersion) + if(version > MaxProtoVersion || version < MinProtoVersion) return -1; if(version > c->version) version = c->version; @@ -2197,7 +2099,6 @@ setVersion(TlsConnection *c, int version) c->version = version; c->finished.n = TLSFinishedLen; } - c->verset = 1; return fprint(c->ctl, "version 0x%x", version); } @@ -2213,8 +2114,10 @@ finishedMatch(TlsConnection *c, Finished *f) static void tlsConnectionFree(TlsConnection *c) { - tlsSecClose(c->sec); - freebytes(c->sid); + if(c == nil) + return; + factotum_rsa_close(c->sec->rpc); + rsapubfree(c->sec->rsapub); freebytes(c->cert); memset(c, 0, sizeof(*c)); free(c); @@ -2528,118 +2431,168 @@ tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1); } -/* - * for setting server session id's - */ -static Lock sidLock; -static long maxSid = 1; +static void +sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) +{ + uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26]; + DigestState *s; + int i, n, len; + + USED(label); + len = 1; + while(nbuf > 0){ + if(len > 26) + return; + for(i = 0; i < len; i++) + tmp[i] = 'A' - 1 + len; + s = sha1(tmp, len, nil, nil); + s = sha1(key, nkey, nil, s); + s = sha1(seed0, nseed0, nil, s); + sha1(seed1, nseed1, sha1dig, s); + s = md5(key, nkey, nil, nil); + md5(sha1dig, SHA1dlen, md5dig, s); + n = MD5dlen; + if(n > nbuf) + n = nbuf; + memmove(buf, md5dig, n); + buf += n; + nbuf -= n; + len++; + } +} + +static void +sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient) +{ + DigestState *s; + uchar h0[MD5dlen], h1[SHA1dlen], pad[48]; + char *label; + + if(isclient) + label = "CLNT"; + else + label = "SRVR"; + + md5((uchar*)label, 4, nil, &hsh.md5); + md5(sec->sec, MasterSecretSize, nil, &hsh.md5); + memset(pad, 0x36, 48); + md5(pad, 48, nil, &hsh.md5); + md5(nil, 0, h0, &hsh.md5); + memset(pad, 0x5C, 48); + s = md5(sec->sec, MasterSecretSize, nil, nil); + s = md5(pad, 48, nil, s); + md5(h0, MD5dlen, finished, s); + + sha1((uchar*)label, 4, nil, &hsh.sha1); + sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1); + memset(pad, 0x36, 40); + sha1(pad, 40, nil, &hsh.sha1); + sha1(nil, 0, h1, &hsh.sha1); + memset(pad, 0x5C, 40); + s = sha1(sec->sec, MasterSecretSize, nil, nil); + s = sha1(pad, 40, nil, s); + sha1(h1, SHA1dlen, finished + MD5dlen, s); +} + +// fill "finished" arg with md5(args)^sha1(args) +static void +tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient) +{ + uchar h0[MD5dlen], h1[SHA1dlen]; + char *label; + + // get current hash value, but allow further messages to be hashed in + md5(nil, 0, h0, &hsh.md5); + sha1(nil, 0, h1, &hsh.sha1); + + if(isclient) + label = "client finished"; + else + label = "server finished"; + tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); +} + +static void +tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isclient) +{ + uchar seed[SHA2_256dlen]; + char *label; + + // get current hash value, but allow further messages to be hashed in + sha2_256(nil, 0, seed, &hsh.sha2_256); + + if(isclient) + label = "client finished"; + else + label = "server finished"; + p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen); +} /* the keys are verified to have the same public components * and to function correctly with pkcs 1 encryption and decryption. */ -static TlsSec* -tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom) +static void +tlsSecInits(TlsSec *sec, int cvers, uchar *crandom) { - TlsSec *sec = emalloc(sizeof(*sec)); - - USED(csid); USED(ncsid); // ignore csid for now - - memmove(sec->crandom, crandom, RandomSize); + memset(sec, 0, sizeof(*sec)); sec->clientVers = cvers; + memmove(sec->crandom, crandom, RandomSize); put32(sec->srandom, time(nil)); genrandom(sec->srandom+4, RandomSize-4); - memmove(srandom, sec->srandom, RandomSize); - - /* - * make up a unique sid: use our pid, and and incrementing id - * can signal no sid by setting nssid to 0. - */ - memset(ssid, 0, SidSize); - put32(ssid, getpid()); - lock(&sidLock); - put32(ssid+4, maxSid++); - unlock(&sidLock); - *nssid = SidSize; - return sec; } static int -tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm) +tlsSecRSAs(TlsSec *sec, Bytes *epm) { Bytes *pm; - if(setVers(sec, vers) < 0) - goto Err; if(epm == nil){ werrstr("no encrypted premaster secret"); - goto Err; + return -1; } // if the client messed up, just continue as if everything is ok, // to prevent attacks to check for correctly formatted messages. pm = pkcs1_decrypt(sec, epm); - if(sec->ok < 0 || pm == nil || pm->len != MasterSecretSize || get16(pm->data) != sec->clientVers){ - sec->ok = -1; + if(pm == nil || pm->len != MasterSecretSize || get16(pm->data) != sec->clientVers){ freebytes(pm); pm = newbytes(MasterSecretSize); genrandom(pm->data, pm->len); } setMasterSecret(sec, pm); return 0; -Err: - sec->ok = -1; - return -1; } -static int -tlsSecPSKs(TlsSec *sec, int vers) +static void +tlsSecPSKs(TlsSec *sec) { - if(setVers(sec, vers) < 0){ - sec->ok = -1; - return -1; - } setMasterSecret(sec, newbytes(sec->psklen)); - return 0; } -static TlsSec* -tlsSecInitc(int cvers, uchar *crandom) +static void +tlsSecInitc(TlsSec *sec, int cvers) { - TlsSec *sec = emalloc(sizeof(*sec)); + memset(sec, 0, sizeof(*sec)); sec->clientVers = cvers; put32(sec->crandom, time(nil)); genrandom(sec->crandom+4, RandomSize-4); - memmove(crandom, sec->crandom, RandomSize); - return sec; } -static int -tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers) +static void +tlsSecPSKc(TlsSec *sec) { - memmove(sec->srandom, srandom, RandomSize); - if(setVers(sec, vers) < 0){ - sec->ok = -1; - return -1; - } setMasterSecret(sec, newbytes(sec->psklen)); - return 0; } static Bytes* -tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers) +tlsSecRSAc(TlsSec *sec, uchar *cert, int ncert) { RSApub *pub; Bytes *pm, *epm; - USED(sid); - USED(nsid); - - memmove(sec->srandom, srandom, RandomSize); - if(setVers(sec, vers) < 0) - goto Err; pub = X509toRSApub(cert, ncert, nil, 0); if(pub == nil){ werrstr("invalid x509/rsa certificate"); - goto Err; + return nil; } pm = newbytes(MasterSecretSize); put16(pm->data, sec->clientVers); @@ -2647,18 +2600,13 @@ tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int n epm = pkcs1_encrypt(pm, pub, 2); setMasterSecret(sec, pm); rsapubfree(pub); - if(epm != nil) - return epm; -Err: - sec->ok = -1; - return nil; + return epm; } static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient) { if(sec->nfin != nfin){ - sec->ok = -1; werrstr("invalid finished exchange"); return -1; } @@ -2666,29 +2614,11 @@ tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclien hsh.sha1.malloced = 0; hsh.sha2_256.malloced = 0; (*sec->setFinished)(sec, hsh, fin, isclient); - return 1; + return 0; } static void -tlsSecOk(TlsSec *sec) -{ - if(sec->ok == 0) - sec->ok = 1; -} - -static void -tlsSecClose(TlsSec *sec) -{ - if(sec == nil) - return; - factotum_rsa_close(sec->rpc); - rsapubfree(sec->rsapub); - free(sec->server); - free(sec); -} - -static int -setVers(TlsSec *sec, int v) +tlsSecVers(TlsSec *sec, int v) { if(v == SSL3Version){ sec->setFinished = sslSetFinished; @@ -2703,22 +2633,36 @@ setVers(TlsSec *sec, int v) sec->nfin = TLSFinishedLen; sec->prf = tls12PRF; } - sec->vers = v; - return 0; } -/* - * generate secret keys from the master secret. - * - * different crypto selections will require different amounts - * of key expansion and use of key expansion data, - * but it's all generated using the same function. - */ -static void -setSecrets(TlsSec *sec, uchar *kd, int nkd) +static int +setSecrets(TlsConnection *c, int isclient) { - (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion", - sec->srandom, RandomSize, sec->crandom, RandomSize); + uchar kd[MaxKeyData]; + char *secrets; + int rv; + + assert(c->nsecret <= sizeof(kd)); + secrets = emalloc(2*c->nsecret); + + /* + * generate secret keys from the master secret. + * + * different cipher selections will require different amounts + * of key expansion and use of key expansion data, + * but it's all generated using the same function. + */ + (*c->sec->prf)(kd, c->nsecret, c->sec->sec, MasterSecretSize, "key expansion", + c->sec->srandom, RandomSize, c->sec->crandom, RandomSize); + + enc64(secrets, 2*c->nsecret, kd, c->nsecret); + memset(kd, 0, c->nsecret); + + rv = fprint(c->ctl, "secret %s %s %d %s", c->digest, c->enc, isclient, secrets); + memset(secrets, 0, 2*c->nsecret); + free(secrets); + + return rv; } /* @@ -2751,103 +2695,6 @@ setMasterSecret(TlsSec *sec, Bytes *pm) freebytes(pm); } -static void -sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) -{ - DigestState *s; - uchar h0[MD5dlen], h1[SHA1dlen], pad[48]; - char *label; - - if(isClient) - label = "CLNT"; - else - label = "SRVR"; - - md5((uchar*)label, 4, nil, &hsh.md5); - md5(sec->sec, MasterSecretSize, nil, &hsh.md5); - memset(pad, 0x36, 48); - md5(pad, 48, nil, &hsh.md5); - md5(nil, 0, h0, &hsh.md5); - memset(pad, 0x5C, 48); - s = md5(sec->sec, MasterSecretSize, nil, nil); - s = md5(pad, 48, nil, s); - md5(h0, MD5dlen, finished, s); - - sha1((uchar*)label, 4, nil, &hsh.sha1); - sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1); - memset(pad, 0x36, 40); - sha1(pad, 40, nil, &hsh.sha1); - sha1(nil, 0, h1, &hsh.sha1); - memset(pad, 0x5C, 40); - s = sha1(sec->sec, MasterSecretSize, nil, nil); - s = sha1(pad, 40, nil, s); - sha1(h1, SHA1dlen, finished + MD5dlen, s); -} - -// fill "finished" arg with md5(args)^sha1(args) -static void -tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) -{ - uchar h0[MD5dlen], h1[SHA1dlen]; - char *label; - - // get current hash value, but allow further messages to be hashed in - md5(nil, 0, h0, &hsh.md5); - sha1(nil, 0, h1, &hsh.sha1); - - if(isClient) - label = "client finished"; - else - label = "server finished"; - tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); -} - -static void -tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) -{ - uchar seed[SHA2_256dlen]; - char *label; - - // get current hash value, but allow further messages to be hashed in - sha2_256(nil, 0, seed, &hsh.sha2_256); - - if(isClient) - label = "client finished"; - else - label = "server finished"; - p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen); -} - -static void -sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) -{ - uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26]; - DigestState *s; - int i, n, len; - - USED(label); - len = 1; - while(nbuf > 0){ - if(len > 26) - return; - for(i = 0; i < len; i++) - tmp[i] = 'A' - 1 + len; - s = sha1(tmp, len, nil, nil); - s = sha1(key, nkey, nil, s); - s = sha1(seed0, nseed0, nil, s); - sha1(seed1, nseed1, sha1dig, s); - s = md5(key, nkey, nil, nil); - md5(sha1dig, SHA1dlen, md5dig, s); - n = MD5dlen; - if(n > nbuf) - n = nbuf; - memmove(buf, md5dig, n); - buf += n; - nbuf -= n; - len++; - } -} - static mpint* bytestomp(Bytes* bytes) {