ssh: actually handle flow control and channel id's

This commit is contained in:
cinap_lenrek 2017-04-21 19:23:56 +02:00
parent 99825e22ed
commit a944c37d68

View file

@ -45,19 +45,33 @@ enum {
MSG_CHANNEL_FAILURE, MSG_CHANNEL_FAILURE,
}; };
enum {
Overhead = 256, // enougth for MSG_CHANNEL_DATA header
MaxPacket = 1<<15,
WinPackets = 8, // (1<<15) * 8 = 256K
};
typedef struct typedef struct
{ {
int pid;
u32int seq; u32int seq;
u32int kex; u32int kex;
u32int chan;
int win;
int pkt;
int eof;
Chachastate cs1; Chachastate cs1;
Chachastate cs2; Chachastate cs2;
char *v;
char eof;
uchar *r; uchar *r;
uchar *w; uchar *w;
uchar b[1<<15]; uchar b[Overhead + MaxPacket];
char *v;
int pid;
Rendez;
} Oneway; } Oneway;
int nsid; int nsid;
@ -902,7 +916,6 @@ dispatch(void)
switch(recv.r[0]){ switch(recv.r[0]){
case MSG_IGNORE: case MSG_IGNORE:
case MSG_GLOBAL_REQUEST: case MSG_GLOBAL_REQUEST:
case MSG_CHANNEL_WINDOW_ADJUST:
return; return;
case MSG_DISCONNECT: case MSG_DISCONNECT:
if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
@ -922,24 +935,38 @@ dispatch(void)
case MSG_CHANNEL_DATA: case MSG_CHANNEL_DATA:
if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0) if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
break; break;
if(c != 0) if(c != recv.chan)
break; break;
if(write(1, s, n) != n) if(write(1, s, n) != n)
sysfatal("write out: %r"); sysfatal("write out: %r");
Winadjust: Winadjust:
sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, c, n); recv.win -= n;
if(recv.win < recv.pkt){
n = WinPackets*recv.pkt;
recv.win += n;
sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
}
return; return;
case MSG_CHANNEL_EXTENDED_DATA: case MSG_CHANNEL_EXTENDED_DATA:
if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0) if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
break; break;
if(c != 0) if(c != recv.chan)
break; break;
if(b == 1) write(2, s, n); if(b == 1) write(2, s, n);
goto Winadjust; goto Winadjust;
case MSG_CHANNEL_WINDOW_ADJUST:
if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
break;
if(c != recv.chan)
break;
send.win += n;
if(send.win >= send.pkt)
rwakeup(&send);
return;
case MSG_CHANNEL_REQUEST: case MSG_CHANNEL_REQUEST:
if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0) if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
break; break;
if(c != 0) if(c != recv.chan)
break; break;
if(n == 11 && memcmp(s, "exit-signal", n) == 0){ if(n == 11 && memcmp(s, "exit-signal", n) == 0){
if(unpack(p, recv.w-p, "s", &s, &n) < 0) if(unpack(p, recv.w-p, "s", &s, &n) < 0)
@ -1044,7 +1071,6 @@ usage(void)
void void
main(int argc, char *argv[]) main(int argc, char *argv[])
{ {
static char buf[8*1024];
static QLock sl; static QLock sl;
int b, n, c; int b, n, c;
char *s; char *s;
@ -1106,6 +1132,8 @@ main(int argc, char *argv[])
sysfatal("bad server version: %s", recv.v); sysfatal("bad server version: %s", recv.v);
recv.v = strdup(recv.v); recv.v = strdup(recv.v);
send.l = recv.l = &sl;
kex(0); kex(0);
if(user == nil) if(user == nil)
@ -1124,12 +1152,16 @@ Next0: switch(recvpkt()){
if(pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0) if(pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
sysfatal("auth: %r"); sysfatal("auth: %r");
recv.pkt = MaxPacket;
recv.win = WinPackets*recv.pkt;
recv.chan = 0;
/* open hailing frequencies */ /* open hailing frequencies */
sendpkt("bsuuu", MSG_CHANNEL_OPEN, sendpkt("bsuuu", MSG_CHANNEL_OPEN,
"session", 7, "session", 7,
0, recv.chan,
8*sizeof(buf), recv.win,
sizeof(buf)); recv.pkt);
Next1: switch(recvpkt()){ Next1: switch(recvpkt()){
default: default:
@ -1143,6 +1175,11 @@ Next1: switch(recvpkt()){
break; break;
} }
if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
sysfatal("bad channel open confirmation");
if(send.pkt <= 0 || send.pkt > MaxPacket)
send.pkt = MaxPacket;
notify(catch); notify(catch);
atexit(shutdown); atexit(shutdown);
@ -1170,7 +1207,7 @@ Next1: switch(recvpkt()){
if(raw) { if(raw) {
rawon(); rawon();
sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST, sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
0, send.chan,
"pty-req", 7, "pty-req", 7,
0, 0,
tty.term, strlen(tty.term), tty.term, strlen(tty.term),
@ -1182,26 +1219,27 @@ Next1: switch(recvpkt()){
} }
if(cmd == nil){ if(cmd == nil){
sendpkt("busb", MSG_CHANNEL_REQUEST, sendpkt("busb", MSG_CHANNEL_REQUEST,
0, send.chan,
"shell", 5, "shell", 5,
0); 0);
} else { } else {
sendpkt("busbs", MSG_CHANNEL_REQUEST, sendpkt("busbs", MSG_CHANNEL_REQUEST,
0, send.chan,
"exec", 4, "exec", 4,
0, 0,
cmd, strlen(cmd)); cmd, strlen(cmd));
} }
for(;;){ for(;;){
static uchar buf[MaxPacket];
qunlock(&sl); qunlock(&sl);
n = read(0, buf, sizeof(buf)); n = read(0, buf, send.pkt);
qlock(&sl); qlock(&sl);
if(send.eof) if(send.eof)
break; break;
if(n < 0 && wasintr()){ if(n < 0 && wasintr()){
if(!raw) break; if(!raw) break;
sendpkt("busbs", MSG_CHANNEL_REQUEST, sendpkt("busbs", MSG_CHANNEL_REQUEST,
0, send.chan,
"signal", 6, "signal", 6,
0, 0,
"INT", 3); "INT", 3);
@ -1210,12 +1248,15 @@ Next1: switch(recvpkt()){
} }
if(n <= 0) if(n <= 0)
break; break;
send.win -= n;
while(send.win < 0)
rsleep(&send);
sendpkt("bus", MSG_CHANNEL_DATA, sendpkt("bus", MSG_CHANNEL_DATA,
0, send.chan,
buf, n); buf, n);
} }
if(send.eof++ == 0) if(send.eof++ == 0)
sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, 0); sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
qunlock(&sl); qunlock(&sl);
exits(nil); exits(nil);