diff --git a/sys/src/cmd/ssh.c b/sys/src/cmd/ssh.c index 34b53099e..b85e6ec48 100644 --- a/sys/src/cmd/ssh.c +++ b/sys/src/cmd/ssh.c @@ -44,6 +44,7 @@ enum { typedef struct { + int pid; u32int seq; u32int kex; Chachastate cs1; @@ -59,19 +60,18 @@ typedef struct int nsid; uchar sid[256]; -int fd, pid1, pid2, intr, raw, debug; +int fd, intr, raw, debug; char *user, *status, *host, *cmd; Oneway recv, send; +void dispatch(void); void shutdown(void) { - int pid = getpid(); - if(pid1 && pid1 != pid) - postnote(PNPROC, pid1, "shutdown"); - if(pid2 && pid2 != pid) - postnote(PNPROC, pid2, "shutdown"); + recv.eof = send.eof = 1; + if(send.pid > 0) + postnote(PNPROC, send.pid, "shutdown"); } void @@ -353,35 +353,6 @@ if(debug > 1) return recv.r[0]; } -void -unexpected(char *info) -{ - char *s; - int n, c; - - switch(recv.r[0]){ - case MSG_DISCONNECT: - if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) - break; - sysfatal("disconnect: (%d) %.*s", c, n, s); - break; - case MSG_IGNORE: - case MSG_GLOBAL_REQUEST: - return; - case MSG_DEBUG: - if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0) - break; - if(c != 0) fprint(2, "%s: %.*s\n", argv0, n, s); - return; - case MSG_USERAUTH_BANNER: - if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0) - break; - if(raw) write(2, s, n); - return; - } - sysfatal("%s got: %.*H", info, (int)(recv.w - recv.r), recv.r); -} - static char sshrsa[] = "ssh-rsa"; int @@ -538,7 +509,7 @@ kex(int gotkexinit) if(!gotkexinit){ Next0: switch(recvpkt()){ default: - unexpected("KEXINIT"); + dispatch(); goto Next0; case MSG_KEXINIT: break; @@ -570,8 +541,10 @@ kex(int gotkexinit) sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc)); Next1: switch(recvpkt()){ default: - unexpected("ECDH_INIT"); + dispatch(); goto Next1; + case MSG_KEXINIT: + sysfatal("inception"); case MSG_ECDH_REPLY: if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0) sysfatal("bad ECDH_REPLY"); @@ -607,8 +580,10 @@ Next1: switch(recvpkt()){ sendpkt("b", MSG_NEWKEYS); Next2: switch(recvpkt()){ default: - unexpected("NEWKEYS"); + dispatch(); goto Next2; + case MSG_KEXINIT: + sysfatal("inception"); case MSG_NEWKEYS: break; } @@ -647,7 +622,7 @@ auth(char *username, char *servicename) sendpkt("bs", MSG_SERVICE_REQUEST, sshuserauth, sizeof(sshuserauth)-1); Next0: switch(recvpkt()){ default: - unexpected("SERVICE_REQUEST"); + dispatch(); goto Next0; case MSG_SERVICE_ACCEPT: break; @@ -690,7 +665,7 @@ Next0: switch(recvpkt()){ pk, npk); Next1: switch(recvpkt()){ default: - unexpected("USERAUTH_REQUEST"); + dispatch(); goto Next1; case MSG_USERAUTH_FAILURE: continue; @@ -733,7 +708,7 @@ Next1: switch(recvpkt()){ sig, nsig); Next2: switch(recvpkt()){ default: - unexpected("USERAUTH_REQUEST"); + dispatch(); goto Next2; case MSG_USERAUTH_FAILURE: continue; @@ -751,6 +726,83 @@ Next2: switch(recvpkt()){ return -1; } +void +dispatch(void) +{ + char *s; + uchar *p; + int n, b, c; + + switch(recv.r[0]){ + case MSG_IGNORE: + case MSG_GLOBAL_REQUEST: + case MSG_CHANNEL_WINDOW_ADJUST: + return; + case MSG_DISCONNECT: + if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) + break; + sysfatal("disconnect: (%d) %.*s", c, n, s); + return; + case MSG_DEBUG: + if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0) + break; + if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s); + return; + case MSG_USERAUTH_BANNER: + if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0) + break; + if(raw) write(2, s, n); + return; + case MSG_CHANNEL_DATA: + if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) + break; + if(c != 0) + break; + if(write(1, s, n) != n) + sysfatal("write out: %r"); + Winadjust: + sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n); + return; + case MSG_CHANNEL_EXTENDED_DATA: + if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0) + break; + if(c != 0) + break; + if(b == 1) write(2, s, n); + goto Winadjust; + case MSG_CHANNEL_REQUEST: + if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0) + break; + if(c != 0) + break; + if(n == 11 && memcmp(s, "exit-signal", n) == 0){ + if(unpack(p, recv.w-p, "s", &s, &n) < 0) + break; + if(n != 0 && status == nil) + status = smprint("%.*s", n, s); + } else if(n == 11 && memcmp(s, "exit-status", n) == 0){ + if(unpack(p, recv.w-p, "u", &n) < 0) + break; + if(n != 0 && status == nil) + status = smprint("%d", n); + } else if(debug) { + fprint(2, "%s: channel request: %.*s\n", argv0, n, s); + } + return; + case MSG_CHANNEL_EOF: + recv.eof = 1; + if(!raw) write(1, "", 0); + return; + case MSG_CHANNEL_CLOSE: + shutdown(); + return; + case MSG_KEXINIT: + kex(1); + return; + } + sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r); +} + char* readline(void) { @@ -830,7 +882,6 @@ main(int argc, char *argv[]) static QLock sl; int b, n, c; char *s; - uchar *p; quotefmtinstall(); fmtinstall('B', mpfmt); @@ -889,7 +940,6 @@ main(int argc, char *argv[]) recv.v = strdup(recv.v); kex(0); - if(user == nil) user = getuser(); if(auth(user, "ssh-connection") < 0) @@ -902,125 +952,92 @@ main(int argc, char *argv[]) sizeof(buf), sizeof(buf)); - while((send.eof | recv.eof) == 0){ - if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0){ - qlock(&sl); - kex(0); - qunlock(&sl); - } - switch(recvpkt()){ - default: - unexpected("CHANNEL"); - continue; - case MSG_KEXINIT: - qlock(&sl); - kex(1); - qunlock(&sl); - continue; - case MSG_CHANNEL_WINDOW_ADJUST: - continue; - case MSG_CHANNEL_EXTENDED_DATA: - if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0) - unexpected("CHANNEL_EXTENDED_DATA"); - if(b == 1) write(2, s, n); - sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n); - continue; - case MSG_CHANNEL_DATA: - if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) - unexpected("CHANNEL_DATA"); - write(1, s, n); - sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n); - continue; - case MSG_CHANNEL_EOF: - recv.eof = 1; - if(!raw) write(1, "", 0); - continue; - case MSG_CHANNEL_OPEN_FAILURE: - if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0) - unexpected("CHANNEL_OPEN_FAILURE"); - sysfatal("channel open failure: (%d) %.*s", b, n, s); - break; - case MSG_CHANNEL_OPEN_CONFIRMATION: - if(raw) { - rawon(); - sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST, - 0, - "pty-req", 7, - 0, - tty.term, strlen(tty.term), - tty.cols, - tty.lines, - tty.xpixels, - tty.ypixels, - "", 0); - } - if(cmd == nil){ - sendpkt("busb", MSG_CHANNEL_REQUEST, - 0, - "shell", 5, - 0); - } else { - sendpkt("busbs", MSG_CHANNEL_REQUEST, - 0, - "exec", 4, - 0, - cmd, strlen(cmd)); - } - if(pid2) - continue; - pid1 = getpid(); - notify(catch); - atexit(shutdown); - n = rfork(RFPROC|RFMEM); - if(n){ - pid2 = n; - continue; - } - qlock(&sl); - for(;;){ - qunlock(&sl); - n = read(0, buf, sizeof(buf)); - qlock(&sl); - if(n < 0 && wasintr()){ - sendpkt("busbs", MSG_CHANNEL_REQUEST, - 0, - "signal", 6, - 0, - "INT", 3); - intr = 0; - continue; - } - if(n <= 0) - break; - sendpkt("bus", MSG_CHANNEL_DATA, - 0, - buf, n); - } - send.eof = 1; - sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0); - qunlock(&sl); - break; - case MSG_CHANNEL_REQUEST: - if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0) - unexpected("CHANNEL_REQUEST"); - if(n == 11 && memcmp(s, "exit-signal", n) == 0){ - if(unpack(p, recv.w-p, "s", &s, &n) < 0) - continue; - if(n != 0 && status == nil) - status = smprint("%.*s", n, s); - } else if(n == 11 && memcmp(s, "exit-status", n) == 0){ - if(unpack(p, recv.w-p, "u", &n) < 0) - continue; - if(n != 0 && status == nil) - status = smprint("%d", n); - } else { - fprint(2, "%s: channel request: %.*s\n", argv0, n, s); - } - continue; - case MSG_CHANNEL_CLOSE: - break; - } +Next0: switch(recvpkt()){ + default: + dispatch(); + goto Next0; + case MSG_CHANNEL_OPEN_FAILURE: + if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0) + n = strlen(s = "???"); + sysfatal("channel open failure: (%d) %.*s", b, n, s); + case MSG_CHANNEL_OPEN_CONFIRMATION: break; } - exits(status); + + notify(catch); + atexit(shutdown); + + recv.pid = getpid(); + n = rfork(RFPROC|RFMEM); + if(n < 0) + sysfatal("fork: %r"); + + /* parent reads and dispatches packets */ + if(n > 0) { + send.pid = n; + while((send.eof|recv.eof) == 0){ + recvpkt(); + qlock(&sl); + dispatch(); + if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0) + kex(0); + qunlock(&sl); + } + exits(status); + } + + /* child reads input and sends packets */ + qlock(&sl); + if(raw) { + rawon(); + sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST, + 0, + "pty-req", 7, + 0, + tty.term, strlen(tty.term), + tty.cols, + tty.lines, + tty.xpixels, + tty.ypixels, + "", 0); + } + if(cmd == nil){ + sendpkt("busb", MSG_CHANNEL_REQUEST, + 0, + "shell", 5, + 0); + } else { + sendpkt("busbs", MSG_CHANNEL_REQUEST, + 0, + "exec", 4, + 0, + cmd, strlen(cmd)); + } + for(;;){ + qunlock(&sl); + n = read(0, buf, sizeof(buf)); + qlock(&sl); + if(send.eof) + break; + if(n < 0 && wasintr()){ + if(!raw) break; + sendpkt("busbs", MSG_CHANNEL_REQUEST, + 0, + "signal", 6, + 0, + "INT", 3); + intr = 0; + continue; + } + if(n <= 0) + break; + sendpkt("bus", MSG_CHANNEL_DATA, + 0, + buf, n); + } + if(send.eof++ == 0) + sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0); + qunlock(&sl); + + exits(nil); }