diff --git a/lib/drivers/ip/transport/tcp/event.c b/lib/drivers/ip/transport/tcp/event.c index 85b7fb29182..03475c890ce 100644 --- a/lib/drivers/ip/transport/tcp/event.c +++ b/lib/drivers/ip/transport/tcp/event.c @@ -46,13 +46,6 @@ BucketCompletionWorker(PVOID Context) ExFreePoolWithTag(Bucket, TDI_BUCKET_TAG); } -static -VOID -SocketContextCloseWorker(PVOID Context) -{ - LibTCPClose(Context); -} - static VOID CompleteBucket(PCONNECTION_ENDPOINT Connection, PTDI_BUCKET Bucket, BOOLEAN Synchronous) @@ -173,7 +166,7 @@ TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb) PIRP Irp; NTSTATUS Status; KIRQL OldIrql; - void *OldSocketContext; + struct tcp_pcb* OldSocketContext; DbgPrint("[IP, TCPAcceptEventHandler] Called\n"); @@ -222,9 +215,9 @@ TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb) /* sanity assert...this should never be in anything else but a CLOSED state */ ASSERT(((struct tcp_pcb*)OldSocketContext)->state == CLOSED); - /* free socket context created in FileOpenConnection, as we're using a new - one; we free it asynchornously because otherwise we create a deadlock */ - ChewCreate(SocketContextCloseWorker, OldSocketContext); + + /* free socket context created in FileOpenConnection, as we're using a new one */ + LibTCPClose(OldSocketContext, TRUE); } DereferenceObject(Bucket->AssociatedEndpoint); @@ -279,7 +272,7 @@ TCPSendEventHandler(void *arg, u16_t space) Status = TCPTranslateError(LibTCPSend(Connection->SocketContext, SendBuffer, - SendLen)); + SendLen, TRUE)); TI_DbgPrint(DEBUG_TCP,("TCP Bytes: %d\n", SendLen)); diff --git a/lib/drivers/lwip/src/include/rosip.h b/lib/drivers/lwip/src/include/rosip.h index f62ab195e59..6f8648c161f 100755 --- a/lib/drivers/lwip/src/include/rosip.h +++ b/lib/drivers/lwip/src/include/rosip.h @@ -16,10 +16,10 @@ extern u32_t TCPRecvEventHandler(void *arg, struct pbuf *p); struct tcp_pcb *LibTCPSocket(void *arg); err_t LibTCPBind(struct tcp_pcb *pcb, struct ip_addr *ipaddr, u16_t port); struct tcp_pcb *LibTCPListen(struct tcp_pcb *pcb, u8_t backlog); -err_t LibTCPSend(struct tcp_pcb *pcb, void *dataptr, u16_t len); +err_t LibTCPSend(struct tcp_pcb *pcb, const void *dataptr, const u16_t len, const int safe); err_t LibTCPConnect(struct tcp_pcb *pcb, struct ip_addr *ipaddr, u16_t port); err_t LibTCPShutdown(struct tcp_pcb *pcb, int shut_rx, int shut_tx); -err_t LibTCPClose(struct tcp_pcb *pcb); +err_t LibTCPClose(struct tcp_pcb *pcb, const int safe); err_t LibTCPGetPeerName(struct tcp_pcb *pcb, struct ip_addr *ipaddr, u16_t *port); err_t LibTCPGetHostName(struct tcp_pcb *pcb, struct ip_addr *ipaddr, u16_t *port); void LibTCPAccept(struct tcp_pcb *pcb, struct tcp_pcb *listen_pcb, void *arg); diff --git a/lib/drivers/lwip/src/rostcp.c b/lib/drivers/lwip/src/rostcp.c index 049454de6ac..1c6cb2c20ad 100755 --- a/lib/drivers/lwip/src/rostcp.c +++ b/lib/drivers/lwip/src/rostcp.c @@ -437,35 +437,58 @@ LibTCPSendCallback(void *arg) } err_t -LibTCPSend(struct tcp_pcb *pcb, void *dataptr, u16_t len) +LibTCPSend(struct tcp_pcb *pcb, const void *dataptr, const u16_t len, const int safe) { - struct send_callback_msg *msg; err_t ret; if (!pcb) return ERR_CLSD; - - msg = ExAllocatePool(NonPagedPool, sizeof(struct send_callback_msg)); - if (msg) + + /* + If we're being called from a handler it means we're in the conetxt of teh tcpip + main thread. Therefore we don't have to queue our request via a callback and we + can execute immediately. + */ + if (safe) { - KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); - msg->Pcb = pcb; - msg->Data = dataptr; - msg->DataLength = len; - - tcpip_callback_with_block(LibTCPSendCallback, msg, 1); - - if (WaitForEventSafely(&msg->Event)) - ret = msg->Error; + if (tcp_sndbuf(pcb) < len) + { + ret = ERR_INPROGRESS; + } else - ret = ERR_CLSD; - - DbgPrint("LibTCPSend(0x%x)\n", pcb); - - ExFreePool(msg); - + { + ret = tcp_write(pcb, dataptr, len, TCP_WRITE_FLAG_COPY); + tcp_output(pcb); + } + return ret; } + else + { + struct send_callback_msg *msg; + + msg = ExAllocatePool(NonPagedPool, sizeof(struct send_callback_msg)); + if (msg) + { + KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); + msg->Pcb = pcb; + msg->Data = dataptr; + msg->DataLength = len; + + tcpip_callback_with_block(LibTCPSendCallback, msg, 1); + + if (WaitForEventSafely(&msg->Event)) + ret = msg->Error; + else + ret = ERR_CLSD; + + DbgPrint("LibTCPSend(0x%x)\n", pcb); + + ExFreePool(msg); + + return ret; + } + } return ERR_MEM; } @@ -496,9 +519,7 @@ LibTCPConnectCallback(void *arg) tcp_recv(msg->Pcb, InternalRecvEventHandler); tcp_sent(msg->Pcb, InternalSendEventHandler); - - //if (msg->Error == ERR_OK) - // msg->Error = ERR_INPROGRESS; + err_t Error = tcp_connect(msg->Pcb, msg->IpAddress, ntohs(msg->Port), InternalConnectEventHandler); msg->Error = Error == ERR_OK ? ERR_INPROGRESS : Error; @@ -648,9 +669,8 @@ LibTCPCloseCallback(void *arg) } err_t -LibTCPClose(struct tcp_pcb *pcb) +LibTCPClose(struct tcp_pcb *pcb, const int safe) { - struct close_callback_msg *msg; err_t ret; DbgPrint("[lwIP, LibTCPClose] Called on pcb = 0x%x\n", pcb); @@ -681,33 +701,59 @@ LibTCPClose(struct tcp_pcb *pcb) tcp_accept(pcb, NULL); DbgPrint("[lwIP, LibTCPClose] Attempting to allocate memory for msg\n"); - - msg = ExAllocatePool(NonPagedPool, sizeof(struct close_callback_msg)); - if (msg) + + /* + If we're being called from a handler it means we're in the conetxt of teh tcpip + main thread. Therefore we don't have to queue our request via a callback and we + can execute immediately. + */ + if (safe) { - DbgPrint("[lwIP, LibTCPClose] Initializing msg->Event\n"); - KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); - - DbgPrint("[lwIP, LibTCPClose] Initializing msg->pcb = 0x%x\n", pcb); - msg->Pcb = pcb; - - DbgPrint("[lwIP, LibTCPClose] Attempting to call LibTCPCloseCallback\n"); - - tcpip_callback_with_block(LibTCPCloseCallback, msg, 1); - - if (WaitForEventSafely(&msg->Event)) - ret = msg->Error; + if (pcb->state == LISTEN) + { + DbgPrint("[lwIP, LibTCPClose] Closing a listener\n"); + ret = tcp_close(pcb); + } else - ret = ERR_CLSD; - - ExFreePool(msg); - - DbgPrint("[lwIP, LibTCPClose] pcb = 0x%x\n", pcb); + { + DbgPrint("[lwIP, LibTCPClose] Aborting a connection\n"); + tcp_abort(pcb); + ret = ERR_OK; + } - DbgPrint("[lwIP, LibTCPClose] Done\n"); - return ret; } + else + { + struct close_callback_msg *msg; + + msg = ExAllocatePool(NonPagedPool, sizeof(struct close_callback_msg)); + if (msg) + { + DbgPrint("[lwIP, LibTCPClose] Initializing msg->Event\n"); + KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); + + DbgPrint("[lwIP, LibTCPClose] Initializing msg->pcb = 0x%x\n", pcb); + msg->Pcb = pcb; + + DbgPrint("[lwIP, LibTCPClose] Attempting to call LibTCPCloseCallback\n"); + + tcpip_callback_with_block(LibTCPCloseCallback, msg, 1); + + if (WaitForEventSafely(&msg->Event)) + ret = msg->Error; + else + ret = ERR_CLSD; + + ExFreePool(msg); + + DbgPrint("[lwIP, LibTCPClose] pcb = 0x%x\n", pcb); + + DbgPrint("[lwIP, LibTCPClose] Done\n"); + + return ret; + } + } DbgPrint("[lwIP, LibTCPClose] Failed to allocate memory\n");