sshnet: fix eof and close handling, use proper packet size, cleanup

This commit is contained in:
cinap_lenrek 2019-04-03 10:49:47 +02:00
parent 3bb1804631
commit a278545e3c

View file

@ -61,6 +61,7 @@ struct Client
int state; int state;
int num; int num;
int servernum; int servernum;
int sentclose;
char *connect; char *connect;
int sendpkt; int sendpkt;
@ -68,6 +69,8 @@ struct Client
int recvwin; int recvwin;
int recvacc; int recvacc;
int eof;
Req *wq; Req *wq;
Req **ewq; Req **ewq;
@ -91,7 +94,8 @@ enum {
MSG_CHANNEL_SUCCESS, MSG_CHANNEL_SUCCESS,
MSG_CHANNEL_FAILURE, MSG_CHANNEL_FAILURE,
MaxPacket = 1<<15, Overhead = 256,
MaxPacket = (1<<15)-256, /* 32K is maxatomic for pipe */
WinPackets = 8, WinPackets = 8,
SESSIONCHAN = 1<<24, SESSIONCHAN = 1<<24,
@ -104,7 +108,7 @@ struct Msg
uchar *rp; uchar *rp;
uchar *wp; uchar *wp;
uchar *ep; uchar *ep;
uchar buf[MaxPacket]; uchar buf[MaxPacket + Overhead];
}; };
#define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u) #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
@ -116,6 +120,7 @@ char *mtpt;
int sshfd; int sshfd;
int localport; int localport;
char localip[] = "::"; char localip[] = "::";
char Ehangup[] = "hangup on network connection";
int int
vpack(uchar *p, int n, char *fmt, va_list a) vpack(uchar *p, int n, char *fmt, va_list a)
@ -341,12 +346,10 @@ matchrmsgs(Client *c)
Msg *m; Msg *m;
int n, rm; int n, rm;
while(c->rq != nil && c->mq != nil){ while((r = c->rq) != nil && (m = c->mq) != nil){
r = c->rq;
c->rq = r->aux; c->rq = r->aux;
r->aux = nil;
rm = 0; rm = 0;
m = c->mq;
n = r->ifcall.count; n = r->ifcall.count;
if(n >= m->wp - m->rp){ if(n >= m->wp - m->rp){
n = m->wp - m->rp; n = m->wp - m->rp;
@ -362,6 +365,15 @@ matchrmsgs(Client *c)
respond(r, nil); respond(r, nil);
adjustwin(c, n); adjustwin(c, n);
} }
if(c->eof){
while((r = c->rq) != nil){
c->rq = r->aux;
r->aux = nil;
r->ofcall.count = 0;
respond(r, nil);
}
}
} }
void void
@ -438,57 +450,48 @@ dialedclient(Client *c)
} }
void void
teardownclient(Client *c) hangupclient(Client *c)
{ {
c->state = Teardown; Req *r;
sendmsg(pack(nil, "bu", MSG_CHANNEL_EOF, c->servernum));
c->eof = 1;
c->recvwin = 0;
c->sendwin = 0;
while((r = c->wq) != nil){
c->wq = r->aux;
r->aux = nil;
respond(r, Ehangup);
}
if(c->state == Established){
c->state = Teardown;
matchrmsgs(c);
return;
}
c->state = Closed;
} }
void void
hangupclient(Client *c) teardownclient(Client *c)
{ {
Req *r, *next; hangupclient(c);
Msg *m, *mnext; if(c->sentclose++ == 0)
sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
c->state = Closed;
for(m=c->mq; m; m=mnext){
mnext = m->link;
free(m);
}
c->mq = nil;
for(r=c->rq; r; r=next){
next = r->aux;
respond(r, "hangup on network connection");
}
c->rq = nil;
for(r=c->wq; r; r=next){
next = r->aux;
respond(r, "hangup on network connection");
}
c->wq = nil;
} }
void void
closeclient(Client *c) closeclient(Client *c)
{ {
Msg *m, *next; Msg *m;
if(--c->ref) if(--c->ref)
return; return;
if(c->state >= Established)
if(c->rq != nil || c->wq != nil) teardownclient(c);
sysfatal("ref count reached zero with requests pending (BUG)"); while((m = c->mq) != nil){
c->mq = m->link;
for(m=c->mq; m; m=next){
next = m->link;
free(m); free(m);
} }
c->mq = nil;
if(c->state != Closed)
teardownclient(c);
} }
void void
sshreadproc(void*) sshreadproc(void*)
@ -810,6 +813,7 @@ ctlwrite(Req *r, Client *c)
nf = getfields(f[1], f, nelem(f), 0, "!"); nf = getfields(f[1], f, nelem(f), 0, "!");
if(nf != 2) if(nf != 2)
goto Badarg; goto Badarg;
c->eof = 0;
c->sendwin = MaxPacket; c->sendwin = MaxPacket;
c->recvwin = WinPackets * MaxPacket; c->recvwin = WinPackets * MaxPacket;
c->recvacc = 0; c->recvacc = 0;
@ -831,7 +835,7 @@ ctlwrite(Req *r, Client *c)
static void static void
dataread(Req *r, Client *c) dataread(Req *r, Client *c)
{ {
if(c->state != Established){ if(c->state < Established){
respond(r, "not connected"); respond(r, "not connected");
return; return;
} }
@ -1028,7 +1032,7 @@ fsflush(Req *r)
static void static void
handlemsg(Msg *m) handlemsg(Msg *m)
{ {
int chan, win, pkt, n, l; int chan, win, pkt, n;
Client *c; Client *c;
char *s; char *s;
@ -1037,7 +1041,7 @@ handlemsg(Msg *m)
if(unpack(m, "_uu", &chan, &n) < 0) if(unpack(m, "_uu", &chan, &n) < 0)
break; break;
c = getclient(chan); c = getclient(chan);
if(c != nil && c->state==Established){ if(c != nil && c->state == Established){
c->sendwin += n; c->sendwin += n;
procwreqs(c); procwreqs(c);
} }
@ -1046,7 +1050,9 @@ handlemsg(Msg *m)
if(unpack(m, "_us", &chan, &s, &n) < 0) if(unpack(m, "_us", &chan, &s, &n) < 0)
break; break;
c = getclient(chan); c = getclient(chan);
if(c != nil && c->state==Established){ if(c != nil && c->state == Established){
if(c->recvwin <= 0)
break;
c->recvwin -= n; c->recvwin -= n;
m->rp = (uchar*)s; m->rp = (uchar*)s;
queuermsg(c, m); queuermsg(c, m);
@ -1058,18 +1064,17 @@ handlemsg(Msg *m)
if(unpack(m, "_u", &chan) < 0) if(unpack(m, "_u", &chan) < 0)
break; break;
c = getclient(chan); c = getclient(chan);
if(c != nil){ if(c != nil && c->state == Established){
hangupclient(c); c->eof = 1;
m->rp = m->wp = m->buf; c->recvwin = 0;
sendmsg(pack(m, "bu", MSG_CHANNEL_CLOSE, c->servernum)); matchrmsgs(c);
return;
} }
break; break;
case MSG_CHANNEL_CLOSE: case MSG_CHANNEL_CLOSE:
if(unpack(m, "_u", &chan) < 0) if(unpack(m, "_u", &chan) < 0)
break; break;
c = getclient(chan); c = getclient(chan);
if(c != nil) if(c != nil && c->state >= Established)
hangupclient(c); hangupclient(c);
break; break;
case MSG_CHANNEL_OPEN_CONFIRMATION: case MSG_CHANNEL_OPEN_CONFIRMATION:
@ -1087,20 +1092,20 @@ handlemsg(Msg *m)
c->sendpkt = pkt; c->sendpkt = pkt;
c->sendwin = win; c->sendwin = win;
c->servernum = n; c->servernum = n;
c->sentclose = 0;
c->state = Established; c->state = Established;
dialedclient(c); dialedclient(c);
break; break;
case MSG_CHANNEL_OPEN_FAILURE: case MSG_CHANNEL_OPEN_FAILURE:
if(unpack(m, "_uus", &chan, &n, &s, &l) < 0) if(unpack(m, "_u____s", &chan, &s, &n) < 0)
break; break;
if(chan == SESSIONCHAN){ if(chan == SESSIONCHAN){
sendp(ssherrchan, smprint("%.*s", utfnlen(s, l), s)); sendp(ssherrchan, smprint("%.*s", utfnlen(s, n), s));
break; break;
} }
c = getclient(chan); c = getclient(chan);
if(c == nil || c->state != Dialing) if(c == nil || c->state != Dialing)
break; break;
c->servernum = n;
c->state = Closed; c->state = Closed;
dialedclient(c); dialedclient(c);
break; break;