From 1160f6c9bf634c17f48e7777939fb876fa1538db Mon Sep 17 00:00:00 2001 From: William Pitcock Date: Sat, 2 Apr 2016 17:05:40 -0500 Subject: [PATCH] wsockd: implement websocket handshake part --- wsockd/Makefile.am | 2 +- wsockd/sha1.c | 143 +++++++++++++++++++++++++++++++++++++++++++++ wsockd/sha1.h | 25 ++++++++ wsockd/wsockd.c | 135 +++++++++++++++++++++++++++++++++++++++--- 4 files changed, 295 insertions(+), 10 deletions(-) create mode 100644 wsockd/sha1.c create mode 100644 wsockd/sha1.h diff --git a/wsockd/Makefile.am b/wsockd/Makefile.am index d4524698..706600ea 100644 --- a/wsockd/Makefile.am +++ b/wsockd/Makefile.am @@ -3,5 +3,5 @@ AM_CFLAGS=$(WARNFLAGS) AM_CPPFLAGS = -I../include -I../librb/include -wsockd_SOURCES = wsockd.c +wsockd_SOURCES = wsockd.c sha1.c wsockd_LDADD = ../librb/src/librb.la diff --git a/wsockd/sha1.c b/wsockd/sha1.c new file mode 100644 index 00000000..3214951d --- /dev/null +++ b/wsockd/sha1.c @@ -0,0 +1,143 @@ +/* + * Based on the SHA-1 C implementation by Steve Reid + * 100% Public Domain + * + * Test Vectors (from FIPS PUB 180-1) + * "abc" + * A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D + * "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + * 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 + * A million repetitions of "a" + * 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F + */ + +#include +#ifdef _WIN32 + #include // for htonl() +#else + #include // for htonl() +#endif + +#include "sha1.h" + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +// blk0() and blk() perform the initial expand. blk0() deals with host endianess +#define blk0(i) (block[i] = htonl(block[i])) +#define blk(i) (block[i&15] = rol(block[(i+13)&15]^block[(i+8)&15]^block[(i+2)&15]^block[i&15],1)) + +// (R0+R1), R2, R3, R4 are the different operations (rounds) used in SHA1 +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + +// hash a single 512-bit block. this is the core of the algorithm +static uint32_t sha1_transform(SHA1 *sha1, const uint8_t buffer[SHA1_BLOCK_LENGTH]) { + uint32_t a, b, c, d, e; + uint32_t block[SHA1_BLOCK_LENGTH / 4]; + + memcpy(&block, buffer, SHA1_BLOCK_LENGTH); + + // copy sha1->state[] to working variables + a = sha1->state[0]; + b = sha1->state[1]; + c = sha1->state[2]; + d = sha1->state[3]; + e = sha1->state[4]; + + // 4 rounds of 20 operations each (loop unrolled) + R0(a,b,c,d,e, 0); R0(e,a,b,c,d, 1); R0(d,e,a,b,c, 2); R0(c,d,e,a,b, 3); + R0(b,c,d,e,a, 4); R0(a,b,c,d,e, 5); R0(e,a,b,c,d, 6); R0(d,e,a,b,c, 7); + R0(c,d,e,a,b, 8); R0(b,c,d,e,a, 9); R0(a,b,c,d,e,10); R0(e,a,b,c,d,11); + R0(d,e,a,b,c,12); R0(c,d,e,a,b,13); R0(b,c,d,e,a,14); R0(a,b,c,d,e,15); + R1(e,a,b,c,d,16); R1(d,e,a,b,c,17); R1(c,d,e,a,b,18); R1(b,c,d,e,a,19); + + R2(a,b,c,d,e,20); R2(e,a,b,c,d,21); R2(d,e,a,b,c,22); R2(c,d,e,a,b,23); + R2(b,c,d,e,a,24); R2(a,b,c,d,e,25); R2(e,a,b,c,d,26); R2(d,e,a,b,c,27); + R2(c,d,e,a,b,28); R2(b,c,d,e,a,29); R2(a,b,c,d,e,30); R2(e,a,b,c,d,31); + R2(d,e,a,b,c,32); R2(c,d,e,a,b,33); R2(b,c,d,e,a,34); R2(a,b,c,d,e,35); + R2(e,a,b,c,d,36); R2(d,e,a,b,c,37); R2(c,d,e,a,b,38); R2(b,c,d,e,a,39); + + R3(a,b,c,d,e,40); R3(e,a,b,c,d,41); R3(d,e,a,b,c,42); R3(c,d,e,a,b,43); + R3(b,c,d,e,a,44); R3(a,b,c,d,e,45); R3(e,a,b,c,d,46); R3(d,e,a,b,c,47); + R3(c,d,e,a,b,48); R3(b,c,d,e,a,49); R3(a,b,c,d,e,50); R3(e,a,b,c,d,51); + R3(d,e,a,b,c,52); R3(c,d,e,a,b,53); R3(b,c,d,e,a,54); R3(a,b,c,d,e,55); + R3(e,a,b,c,d,56); R3(d,e,a,b,c,57); R3(c,d,e,a,b,58); R3(b,c,d,e,a,59); + + R4(a,b,c,d,e,60); R4(e,a,b,c,d,61); R4(d,e,a,b,c,62); R4(c,d,e,a,b,63); + R4(b,c,d,e,a,64); R4(a,b,c,d,e,65); R4(e,a,b,c,d,66); R4(d,e,a,b,c,67); + R4(c,d,e,a,b,68); R4(b,c,d,e,a,69); R4(a,b,c,d,e,70); R4(e,a,b,c,d,71); + R4(d,e,a,b,c,72); R4(c,d,e,a,b,73); R4(b,c,d,e,a,74); R4(a,b,c,d,e,75); + R4(e,a,b,c,d,76); R4(d,e,a,b,c,77); R4(c,d,e,a,b,78); R4(b,c,d,e,a,79); + + // add the working variables back into sha1->state[] + sha1->state[0] += a; + sha1->state[1] += b; + sha1->state[2] += c; + sha1->state[3] += d; + sha1->state[4] += e; + + // wipe variables + a = b = c = d = e = 0; + + return a; // return a to avoid dead-store warning from clang static analyzer +} + +void sha1_init(SHA1 *sha1) { + sha1->state[0] = 0x67452301; + sha1->state[1] = 0xEFCDAB89; + sha1->state[2] = 0x98BADCFE; + sha1->state[3] = 0x10325476; + sha1->state[4] = 0xC3D2E1F0; + sha1->count = 0; +} + +void sha1_update(SHA1 *sha1, const uint8_t *data, size_t length) { + size_t i, j; + + j = (size_t)((sha1->count >> 3) & 63); + sha1->count += (length << 3); + + if ((j + length) > 63) { + i = 64 - j; + + memcpy(&sha1->buffer[j], data, i); + sha1_transform(sha1, sha1->buffer); + + for (; i + 63 < length; i += 64) { + sha1_transform(sha1, &data[i]); + } + + j = 0; + } else { + i = 0; + } + + memcpy(&sha1->buffer[j], &data[i], length - i); +} + +void sha1_final(SHA1 *sha1, uint8_t digest[SHA1_DIGEST_LENGTH]) { + uint32_t i; + uint8_t count[8]; + + for (i = 0; i < 8; i++) { + // this is endian independent + count[i] = (uint8_t)((sha1->count >> ((7 - (i & 7)) * 8)) & 255); + } + + sha1_update(sha1, (uint8_t *)"\200", 1); + + while ((sha1->count & 504) != 448) { + sha1_update(sha1, (uint8_t *)"\0", 1); + } + + sha1_update(sha1, count, 8); + + for (i = 0; i < SHA1_DIGEST_LENGTH; i++) { + digest[i] = (uint8_t)((sha1->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255); + } + + memset(sha1, 0, sizeof(*sha1)); +} diff --git a/wsockd/sha1.h b/wsockd/sha1.h new file mode 100644 index 00000000..cdd7d082 --- /dev/null +++ b/wsockd/sha1.h @@ -0,0 +1,25 @@ +/* + * Based on the SHA-1 C implementation by Steve Reid + * 100% Public Domain + */ + +#ifndef SHA1_H +#define SHA1_H + +#include +#include + +#define SHA1_BLOCK_LENGTH 64 +#define SHA1_DIGEST_LENGTH 20 + +typedef struct { + uint32_t state[5]; + uint64_t count; + uint8_t buffer[SHA1_BLOCK_LENGTH]; +} SHA1; + +void sha1_init(SHA1 *sha1); +void sha1_update(SHA1 *sha1, const uint8_t *data, size_t length); +void sha1_final(SHA1 *sha1, uint8_t digest[SHA1_DIGEST_LENGTH]); + +#endif // SHA1_H diff --git a/wsockd/wsockd.c b/wsockd/wsockd.c index ab191d02..db1d3ddc 100644 --- a/wsockd/wsockd.c +++ b/wsockd/wsockd.c @@ -21,12 +21,17 @@ */ #include "stdinc.h" +#include "sha1.h" #define MAXPASSFD 4 #ifndef READBUF_SIZE #define READBUF_SIZE 16384 #endif +#define WEBSOCKET_SERVER_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +#define WEBSOCKET_ANSWER_STRING_1 "HTTP/1.1 101 Switching Protocols\r\nAccess-Control-Allow-Origin: *\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " +#define WEBSOCKET_ANSWER_STRING_2 "\r\n\r\n" + static void setup_signals(void); static pid_t ppid; @@ -71,8 +76,8 @@ typedef struct _conn rb_dlink_node node; mod_ctl_t *ctl; - buf_head_t modbuf_out; - buf_head_t modbuf_in; + rawbuf_head_t *modbuf_out; + rawbuf_head_t *modbuf_in; buf_head_t plainbuf_out; buf_head_t plainbuf_in; @@ -86,23 +91,33 @@ typedef struct _conn uint64_t plain_in; uint64_t plain_out; uint8_t flags; + + char client_key[37]; /* maximum 36 bytes + nul */ } conn_t; +static void close_conn(conn_t * conn, int wait_plain, const char *fmt, ...); +static void conn_mod_read_cb(rb_fde_t *fd, void *data); +static void conn_plain_read_cb(rb_fde_t *fd, void *data); + #define FLAG_CORK 0x01 #define FLAG_DEAD 0x02 #define FLAG_WSOCK 0x04 +#define FLAG_KEYED 0x08 #define IsCork(x) ((x)->flags & FLAG_CORK) #define IsDead(x) ((x)->flags & FLAG_DEAD) #define IsWS(x) ((x)->flags & FLAG_WSOCK) +#define IsKeyed(x) ((x)->flags & FLAG_KEYED) #define SetCork(x) ((x)->flags |= FLAG_CORK) #define SetDead(x) ((x)->flags |= FLAG_DEAD) #define SetWS(x) ((x)->flags |= FLAG_WSOCK) +#define SetKeyed(x) ((x)->flags |= FLAG_KEYED) #define ClearCork(x) ((x)->flags &= ~FLAG_CORK) #define ClearDead(x) ((x)->flags &= ~FLAG_DEAD) #define ClearWS(x) ((x)->flags &= ~FLAG_WSOCK) +#define ClearKeyed(x) ((x)->flags &= ~FLAG_KEYED) #define NO_WAIT 0x0 #define WAIT_PLAIN 0x1 @@ -112,6 +127,8 @@ typedef struct _conn #define CONN_HASH_SIZE 2000 #define connid_hash(x) (&connid_hash_table[(x % CONN_HASH_SIZE)]) +static const char *remote_closed = "Remote host closed the connection"; + static rb_dlink_list connid_hash_table[CONN_HASH_SIZE]; static rb_dlink_list dead_list; @@ -196,8 +213,8 @@ free_conn(conn_t * conn) rb_linebuf_donebuf(&conn->plainbuf_in); rb_linebuf_donebuf(&conn->plainbuf_out); - rb_linebuf_donebuf(&conn->modbuf_in); - rb_linebuf_donebuf(&conn->modbuf_out); + rb_free_rawbuffer(conn->modbuf_in); + rb_free_rawbuffer(conn->modbuf_out); rb_free(conn); } @@ -217,6 +234,56 @@ clean_dead_conns(void *unused) dead_list.tail = dead_list.head = NULL; } +static void +conn_mod_write_sendq(rb_fde_t *fd, void *data) +{ + conn_t *conn = data; + const char *err; + int retlen; + + if(IsDead(conn)) + return; + + while((retlen = rb_rawbuf_flush(conn->modbuf_out, fd)) > 0) + conn->mod_out += retlen; + + if(retlen == 0 || (retlen < 0 && !rb_ignore_errno(errno))) + { + if(retlen == 0) + close_conn(conn, WAIT_PLAIN, "%s", remote_closed); + err = strerror(errno); + close_conn(conn, WAIT_PLAIN, "Write error: %s", err); + return; + } + + if(rb_rawbuf_length(conn->modbuf_out) > 0) + rb_setselect(conn->mod_fd, RB_SELECT_WRITE, conn_mod_write_sendq, conn); + else + rb_setselect(conn->mod_fd, RB_SELECT_WRITE, NULL, NULL); + + if(IsCork(conn) && rb_rawbuf_length(conn->modbuf_out) == 0) + { + ClearCork(conn); + conn_plain_read_cb(conn->plain_fd, conn); + } +} + +static void +conn_mod_write(conn_t * conn, void *data, size_t len) +{ + if(IsDead(conn)) /* no point in queueing to a dead man */ + return; + rb_rawbuf_append(conn->modbuf_out, data, len); +} + +static void +conn_plain_write(conn_t * conn, void *data, size_t len) +{ + if(IsDead(conn)) /* again no point in queueing to dead men */ + return; + rb_linebuf_put(&conn->plainbuf_out, data, len); +} + static void mod_write_ctl(rb_fde_t *F, void *data) { @@ -270,7 +337,7 @@ close_conn(conn_t * conn, int wait_plain, const char *fmt, ...) if(IsDead(conn)) return; - rb_linebuf_flush(conn->mod_fd, &conn->modbuf_out); + rb_rawbuf_flush(conn->modbuf_out, conn->mod_fd); rb_linebuf_flush(conn->plain_fd, &conn->plainbuf_out); rb_close(conn->mod_fd); SetDead(conn); @@ -312,8 +379,8 @@ make_conn(mod_ctl_t * ctl, rb_fde_t *mod_fd, rb_fde_t *plain_fd) rb_linebuf_newbuf(&conn->plainbuf_in); rb_linebuf_newbuf(&conn->plainbuf_out); - rb_linebuf_newbuf(&conn->modbuf_in); - rb_linebuf_newbuf(&conn->modbuf_out); + conn->modbuf_in = rb_new_rawbuffer(); + conn->modbuf_out = rb_new_rawbuffer(); return conn; } @@ -335,10 +402,59 @@ conn_mod_handshake_process(conn_t *conn) while (1) { - size_t dolen = rb_linebuf_get(&conn->modbuf_in, inbuf, READBUF_SIZE, LINEBUF_COMPLETE, LINEBUF_PARSED); + char *p = NULL; + + size_t dolen = rb_rawbuf_get(conn->modbuf_in, inbuf, sizeof inbuf); if (!dolen) break; + + if ((p = strcasestr(inbuf, "Sec-WebSocket-Key:")) != NULL) + { + char *start, *end; + + start = p + strlen("Sec-WebSocket-Key:"); + + for (; start < (inbuf + READBUF_SIZE) && *start; start++) + { + if (*start != ' ' && *start != '\t') + break; + } + + for (end = start; end < (inbuf + READBUF_SIZE) && *end; end++) + { + if (*end == '\r' || *end == '\n') + { + *end = '\0'; + break; + } + } + + rb_strlcpy(conn->client_key, start, sizeof(conn->client_key)); + SetKeyed(conn); + } } + + if (IsKeyed(conn)) + { + SHA1 sha1; + uint8_t digest[SHA1_DIGEST_LENGTH]; + char *resp; + + sha1_init(&sha1); + sha1_update(&sha1, (uint8_t *) conn->client_key, strlen(conn->client_key)); + sha1_update(&sha1, (uint8_t *) WEBSOCKET_SERVER_KEY, strlen(WEBSOCKET_SERVER_KEY)); + sha1_final(&sha1, digest); + + resp = (char *) rb_base64_encode(digest, SHA1_DIGEST_LENGTH); + + conn_mod_write(conn, WEBSOCKET_ANSWER_STRING_1, strlen(WEBSOCKET_ANSWER_STRING_1)); + conn_mod_write(conn, resp, strlen(resp)); + conn_mod_write(conn, WEBSOCKET_ANSWER_STRING_2, strlen(WEBSOCKET_ANSWER_STRING_2)); + + rb_free(resp); + } + + conn_mod_write_sendq(conn->mod_fd, conn); } static void @@ -375,7 +491,7 @@ conn_mod_handshake_cb(rb_fde_t *fd, void *data) return; } - int res = rb_linebuf_parse(&conn->modbuf_in, inbuf, length, 0); + rb_rawbuf_append(conn->modbuf_in, inbuf, length); conn_mod_handshake_process(conn); if (length < sizeof(inbuf)) @@ -578,6 +694,7 @@ main(int argc, char **argv) setup_signals(); rb_lib_init(NULL, NULL, NULL, 0, maxfd, 1024, 4096); rb_linebuf_init(4096); + rb_init_rawbuffers(4096); mod_ctl = rb_malloc(sizeof(mod_ctl_t)); mod_ctl->F = rb_open(ctlfd, RB_FD_SOCKET, "ircd control socket");