diff --git a/sys/src/libsec/port/tlshand.c b/sys/src/libsec/port/tlshand.c index 13baa6a62..a315f39da 100644 --- a/sys/src/libsec/port/tlshand.c +++ b/sys/src/libsec/port/tlshand.c @@ -63,6 +63,12 @@ typedef struct Finished{ int n; } Finished; +typedef struct HandshakeHash { + MD5state md5; + SHAstate sha1; + SHA2_256state sha2_256; +} HandshakeHash; + typedef struct TlsConnection{ TlsSec *sec; // security management goo int hand, ctl; // record layer file descriptors @@ -95,8 +101,7 @@ typedef struct TlsConnection{ int nsecret; // amount of secret data to init keys // for finished messages - MD5state hsmd5; // handshake hash - SHAstate hssha1; // handshake hash + HandshakeHash handhash; Finished finished; } TlsConnection; @@ -157,7 +162,7 @@ typedef struct TlsSec{ int vers; // final version // byte generation and handshake checksum void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int); - void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int); + void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int); int nfin; } TlsSec; @@ -166,7 +171,8 @@ enum { SSL3Version = 0x0300, TLS10Version = 0x0301, TLS11Version = 0x0302, - ProtocolVersion = TLS11Version, // maximum version we speak + TLS12Version = 0x0303, + ProtocolVersion = TLS11Version, // maximum version we speak (server) MinProtoVersion = 0x0300, // limits on version we accept MaxProtoVersion = 0x03ff, }; @@ -331,7 +337,7 @@ static TlsSec* tlsSecInitc(int cvers, uchar *crandom); static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers); static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys); static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys); -static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient); +static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient); static void tlsSecOk(TlsSec *sec); static void tlsSecKill(TlsSec *sec); static void tlsSecClose(TlsSec *sec); @@ -341,8 +347,9 @@ static void setSecrets(TlsSec *sec, uchar *kd, int nkd); static Bytes* clientMasterSecret(TlsSec *sec, RSApub *pub); static Bytes* pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype); static Bytes* pkcs1_decrypt(TlsSec *sec, Bytes *cipher); -static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient); -static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient); +static void tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); +static void tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); +static void sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient); static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1); static int setVers(TlsSec *sec, int version); @@ -693,7 +700,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . msgClear(&m); /* no CertificateVerify; skip to Finished */ - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ + if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } @@ -715,7 +722,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . goto Err; } - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ + if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){ tlsError(c, EInternalError, "can't set finished: %r"); goto Err; } @@ -961,7 +968,12 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, return nil; epm = nil; c = emalloc(sizeof(TlsConnection)); - c->version = ProtocolVersion; + c->version = TLS12Version; + + // client certificate signature not implemented for TLS1.2 + if(cert != nil && certlen > 0) + c->version = TLS11Version; + c->ctl = ctl; c->hand = hand; c->trace = trace; @@ -1114,25 +1126,16 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, goto Err; msgClear(&m); - /* CertificateVerify */ - /*XXX I should only send this when it is not DH right? - Also we need to know which TLS key - we have to use in case there are more than one*/ - if(cert){ - m.tag = HCertificateVerify; + /* certificate verify */ + if(creq && cert != nil && certlen > 0) { uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */ - MD5state hsmd5_save; - SHAstate hssha1_save; - - /* save the state for the Finish message */ + HandshakeHash hsave; - hsmd5_save = c->hsmd5; - hssha1_save = c->hssha1; - md5(nil, 0, hshashes, &c->hsmd5); - sha1(nil, 0, hshashes+MD5dlen, &c->hssha1); - - c->hsmd5 = hsmd5_save; - c->hssha1 = hssha1_save; + /* save the state for the Finish message */ + hsave = c->handhash; + md5(nil, 0, hshashes, &c->handhash.md5); + sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1); + c->handhash = hsave; c->sec->rpc = factotum_rsa_open(cert, certlen); if(c->sec->rpc == nil){ @@ -1154,6 +1157,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, m.u.certificateVerify.signature = mptobytes(signedMP); mpfree(signedMP); + m.tag = HCertificateVerify; if(!msgSend(c, &m, AFlush)) goto Err; msgClear(&m); @@ -1167,7 +1171,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, // Cipherchange must occur immediately before Finished to avoid // potential hole; see section 4.3 of Wagner Schneier 1996. - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){ + if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){ tlsError(c, EInternalError, "can't set finished 1: %r"); goto Err; } @@ -1179,7 +1183,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, } msgClear(&m); - if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){ + if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){ tlsError(c, EInternalError, "can't set finished 0: %r"); goto Err; } @@ -1216,6 +1220,15 @@ Err: //================= message functions ======================== +static void +msgHash(TlsConnection *c, uchar *p, int n) +{ + md5(p, n, 0, &c->handhash.md5); + sha1(p, n, 0, &c->handhash.sha1); + if(c->version >= TLS12Version) + sha2_256(p, n, 0, &c->handhash.sha2_256); +} + static int msgSend(TlsConnection *c, Msg *m, int act) { @@ -1352,10 +1365,8 @@ msgSend(TlsConnection *c, Msg *m, int act) put24(c->sendp+1, n-4); // remember hash of Handshake messages - if(m->tag != HHelloRequest) { - md5(c->sendp, n, 0, &c->hsmd5); - sha1(c->sendp, n, 0, &c->hssha1); - } + if(m->tag != HHelloRequest) + msgHash(c, c->sendp, n); c->sendp = p; if(act == AFlush){ @@ -1430,8 +1441,7 @@ msgRecv(TlsConnection *c, Msg *m) p = tlsReadN(c, n); if(p == nil) return 0; - md5(p, n, 0, &c->hsmd5); - sha1(p, n, 0, &c->hssha1); + msgHash(c, p, n); m->tag = HClientHello; if(n < 22) goto Short; @@ -1468,15 +1478,13 @@ msgRecv(TlsConnection *c, Msg *m) m->u.clientHello.compressors->data[0] = CompressionNull; goto Ok; } - md5(p, 4, 0, &c->hsmd5); - sha1(p, 4, 0, &c->hssha1); + msgHash(c, p, 4); p = tlsReadN(c, n); if(p == nil) return 0; - md5(p, n, 0, &c->hsmd5); - sha1(p, n, 0, &c->hssha1); + msgHash(c, p, n); m->tag = type; @@ -1678,6 +1686,12 @@ msgRecv(TlsConnection *c, Msg *m) break; } if(n >= 2){ + if(c->version >= TLS12Version){ + /* signature hash algorithm */ + p += 2, n -= 2; + if(n < 2) + goto Short; + } nn = get16(p); p += 2, n -= 2; if(nn > 0 && nn <= n){ @@ -2265,20 +2279,55 @@ tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, u } } +static void +p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed) +{ + uchar ai[SHA2_256dlen], tmp[SHA2_256dlen]; + SHAstate *s; + int n; + + // generate a1 + s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil); + hmac_sha2_256(seed, nseed, key, nkey, ai, s); + + while(nbuf > 0) { + s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil); + s = hmac_sha2_256(label, nlabel, key, nkey, nil, s); + hmac_sha2_256(seed, nseed, key, nkey, tmp, s); + n = SHA2_256dlen; + if(n > nbuf) + n = nbuf; + memmove(buf, tmp, n); + buf += n; + nbuf -= n; + hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil); + memmove(ai, tmp, SHA2_256dlen); + } +} + // fill buf with md5(args)^sha1(args) static void -tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) +tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { - int i; int nlabel = strlen(label); int n = (nkey + 1) >> 1; - for(i = 0; i < nbuf; i++) - buf[i] = 0; + memset(buf, 0, nbuf); tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1); tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1); } +static void +tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) +{ + uchar seed[2*RandomSize]; + + assert(nseed0+nseed1 <= sizeof(seed)); + memmove(seed, seed0, nseed0); + memmove(seed+nseed0, seed1, nseed1); + p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1); +} + /* * for setting server session id's */ @@ -2369,16 +2418,17 @@ Err: } static int -tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient) +tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient) { if(sec->nfin != nfin){ sec->ok = -1; werrstr("invalid finished exchange"); return -1; } - md5.malloced = 0; - sha1.malloced = 0; - (*sec->setFinished)(sec, md5, sha1, fin, isclient); + hsh.md5.malloced = 0; + hsh.sha1.malloced = 0; + hsh.sha2_256.malloced = 0; + (*sec->setFinished)(sec, hsh, fin, isclient); return 1; } @@ -2415,10 +2465,14 @@ setVers(TlsSec *sec, int v) sec->setFinished = sslSetFinished; sec->nfin = SSL3FinishedLen; sec->prf = sslPRF; - }else{ - sec->setFinished = tlsSetFinished; + }else if(v < TLS12Version) { + sec->setFinished = tls10SetFinished; sec->nfin = TLSFinishedLen; - sec->prf = tlsPRF; + sec->prf = tls10PRF; + }else { + sec->setFinished = tls12SetFinished; + sec->nfin = TLSFinishedLen; + sec->prf = tls12PRF; } sec->vers = v; return 0; @@ -2488,7 +2542,7 @@ clientMasterSecret(TlsSec *sec, RSApub *pub) } static void -sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) +sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) { DigestState *s; uchar h0[MD5dlen], h1[SHA1dlen], pad[48]; @@ -2499,21 +2553,21 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in else label = "SRVR"; - md5((uchar*)label, 4, nil, &hsmd5); - md5(sec->sec, MasterSecretSize, nil, &hsmd5); + md5((uchar*)label, 4, nil, &hsh.md5); + md5(sec->sec, MasterSecretSize, nil, &hsh.md5); memset(pad, 0x36, 48); - md5(pad, 48, nil, &hsmd5); - md5(nil, 0, h0, &hsmd5); + md5(pad, 48, nil, &hsh.md5); + md5(nil, 0, h0, &hsh.md5); memset(pad, 0x5C, 48); s = md5(sec->sec, MasterSecretSize, nil, nil); s = md5(pad, 48, nil, s); md5(h0, MD5dlen, finished, s); - sha1((uchar*)label, 4, nil, &hssha1); - sha1(sec->sec, MasterSecretSize, nil, &hssha1); + sha1((uchar*)label, 4, nil, &hsh.sha1); + sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1); memset(pad, 0x36, 40); - sha1(pad, 40, nil, &hssha1); - sha1(nil, 0, h1, &hssha1); + sha1(pad, 40, nil, &hsh.sha1); + sha1(nil, 0, h1, &hsh.sha1); memset(pad, 0x5C, 40); s = sha1(sec->sec, MasterSecretSize, nil, nil); s = sha1(pad, 40, nil, s); @@ -2522,27 +2576,43 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in // fill "finished" arg with md5(args)^sha1(args) static void -tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient) +tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) { uchar h0[MD5dlen], h1[SHA1dlen]; char *label; // get current hash value, but allow further messages to be hashed in - md5(nil, 0, h0, &hsmd5); - sha1(nil, 0, h1, &hssha1); + md5(nil, 0, h0, &hsh.md5); + sha1(nil, 0, h1, &hsh.sha1); if(isClient) label = "client finished"; else label = "server finished"; - tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); + tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen); +} + +static void +tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient) +{ + uchar seed[SHA2_256dlen]; + char *label; + + // get current hash value, but allow further messages to be hashed in + sha2_256(nil, 0, seed, &hsh.sha2_256); + + if(isClient) + label = "client finished"; + else + label = "server finished"; + p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen); } static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1) { - DigestState *s; uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26]; + DigestState *s; int i, n, len; USED(label);