- Modify the signatures of the lwIP wrapper interfaces. This is the first step toward fixing some nasty crashes related to race conditions.
[TCPIP]
- Call the new lwIP wrapper interfaces appropriately.

svn path=/branches/GSoC_2011/TcpIpDriver/; revision=52767
This commit is contained in:
Claudiu Mihail 2011-07-21 20:58:54 +00:00
parent 781643393b
commit 00bfbb576b
6 changed files with 163 additions and 129 deletions

View file

@ -67,7 +67,7 @@ NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
AddressToBind.addr = Connection->AddressFile->Address.Address.IPv4Address; AddressToBind.addr = Connection->AddressFile->Address.Address.IPv4Address;
Status = TCPTranslateError(LibTCPBind(Connection->SocketContext, Status = TCPTranslateError(LibTCPBind(Connection,
&AddressToBind, &AddressToBind,
Connection->AddressFile->Port)); Connection->AddressFile->Port));
@ -91,7 +91,7 @@ NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
{ {
Connection->SocketContext = LibTCPListen(Connection->SocketContext, Backlog); Connection->SocketContext = LibTCPListen(Connection, Backlog);
if (!Connection->SocketContext) if (!Connection->SocketContext)
Status = STATUS_UNSUCCESSFUL; Status = STATUS_UNSUCCESSFUL;
} }

View file

@ -143,30 +143,32 @@ FlushAllQueues(PCONNECTION_ENDPOINT Connection, NTSTATUS Status)
VOID VOID
TCPFinEventHandler(void *arg, err_t err) TCPFinEventHandler(void *arg, err_t err)
{ {
PCONNECTION_ENDPOINT Connection = arg; PCONNECTION_ENDPOINT Connection = (PCONNECTION_ENDPOINT)arg;
DbgPrint("[IP, TCPFinEventHandler] Called for Connection( 0x%x )-> SocketContext = pcb (0x%x)\n", Connection, Connection->SocketContext);
/* Only clear the pointer if the shutdown was caused by an error */ /* Only clear the pointer if the shutdown was caused by an error */
if (err != ERR_OK) if (err != ERR_OK)
{ {
/* We're already closed by the error so we don't want to call lwip_close */ /* We're already closed by the error so we don't want to call lwip_close */
DbgPrint("[IP, TCPFinEventHandler] MAKING Connection( 0x%x )-> SocketContext = pcb (0x%x) NULL\n", Connection, Connection->SocketContext);
Connection->SocketContext = NULL; Connection->SocketContext = NULL;
} }
DbgPrint("[IP, TCPFinEventHandler] Called for Connection( 0x%x )-> SocketContext = pcb (0x%x)\n", Connection, Connection->SocketContext);
FlushAllQueues(Connection, TCPTranslateError(err)); FlushAllQueues(Connection, TCPTranslateError(err));
DbgPrint("[IP, TCPFinEventHandler] Done\n");
} }
VOID VOID
TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb) TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb)
{ {
PCONNECTION_ENDPOINT Connection = arg; PCONNECTION_ENDPOINT Connection = (PCONNECTION_ENDPOINT)arg;
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
PLIST_ENTRY Entry; PLIST_ENTRY Entry;
PIRP Irp; PIRP Irp;
NTSTATUS Status; NTSTATUS Status;
KIRQL OldIrql; KIRQL OldIrql;
struct tcp_pcb* OldSocketContext;
DbgPrint("[IP, TCPAcceptEventHandler] Called\n"); DbgPrint("[IP, TCPAcceptEventHandler] Called\n");
@ -202,22 +204,19 @@ TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb)
LockObject(Bucket->AssociatedEndpoint, &OldIrql); LockObject(Bucket->AssociatedEndpoint, &OldIrql);
/* sanity assert...this should never be in anything else but a CLOSED state */
ASSERT(((struct tcp_pcb*)Bucket->AssociatedEndpoint->SocketContext)->state == CLOSED);
/* free socket context created in FileOpenConnection, as we're using a new one */
LibTCPClose(Bucket->AssociatedEndpoint, TRUE);
/* free previously created socket context (we don't use it, we use newpcb) */ /* free previously created socket context (we don't use it, we use newpcb) */
OldSocketContext = Bucket->AssociatedEndpoint->SocketContext;
Bucket->AssociatedEndpoint->SocketContext = newpcb; Bucket->AssociatedEndpoint->SocketContext = newpcb;
LibTCPAccept(newpcb, LibTCPAccept(newpcb, (PTCP_PCB)Connection->SocketContext, Bucket->AssociatedEndpoint);
(struct tcp_pcb*)Connection->SocketContext,
Bucket->AssociatedEndpoint);
DbgPrint("[IP, TCPAcceptEventHandler] Trying to unlock Bucket->AssociatedEndpoint\n"); DbgPrint("[IP, TCPAcceptEventHandler] Trying to unlock Bucket->AssociatedEndpoint\n");
UnlockObject(Bucket->AssociatedEndpoint, OldIrql); UnlockObject(Bucket->AssociatedEndpoint, OldIrql);
/* sanity assert...this should never be in anything else but a CLOSED state */
ASSERT(((struct tcp_pcb*)OldSocketContext)->state == CLOSED);
/* free socket context created in FileOpenConnection, as we're using a new one */
LibTCPClose(OldSocketContext, TRUE);
} }
DereferenceObject(Bucket->AssociatedEndpoint); DereferenceObject(Bucket->AssociatedEndpoint);
@ -270,7 +269,7 @@ TCPSendEventHandler(void *arg, u16_t space)
("Connection->SocketContext: %x\n", ("Connection->SocketContext: %x\n",
Connection->SocketContext)); Connection->SocketContext));
Status = TCPTranslateError(LibTCPSend(Connection->SocketContext, Status = TCPTranslateError(LibTCPSend(Connection,
SendBuffer, SendBuffer,
SendLen, TRUE)); SendLen, TRUE));
@ -368,7 +367,7 @@ TCPRecvEventHandler(void *arg, struct pbuf *p)
VOID VOID
TCPConnectEventHandler(void *arg, err_t err) TCPConnectEventHandler(void *arg, err_t err)
{ {
PCONNECTION_ENDPOINT Connection = arg; PCONNECTION_ENDPOINT Connection = (PCONNECTION_ENDPOINT)arg;
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
PLIST_ENTRY Entry; PLIST_ENTRY Entry;

View file

@ -115,19 +115,19 @@ NTSTATUS TCPClose
DbgPrint("[IP, TCPClose] Called for Connection( 0x%x )->SocketConext( 0x%x )\n", Connection, Connection->SocketContext); DbgPrint("[IP, TCPClose] Called for Connection( 0x%x )->SocketConext( 0x%x )\n", Connection, Connection->SocketContext);
Socket = Connection->SocketContext; Socket = Connection->SocketContext;
Connection->SocketContext = NULL; //Connection->SocketContext = NULL;
/* We should not be associated to an address file at this point */ /* We should not be associated to an address file at this point */
ASSERT(!Connection->AddressFile); ASSERT(!Connection->AddressFile);
/* Don't try to close again if the other side closed us already */ /* Don't try to close again if the other side closed us already */
if (Socket) if (Connection->SocketContext)
{ {
FlushAllQueues(Connection, STATUS_CANCELLED); FlushAllQueues(Connection, STATUS_CANCELLED);
DbgPrint("[IP, TCPClose] Socket (pcb) = 0x%x\n", Socket); DbgPrint("[IP, TCPClose] Socket (pcb) = 0x%x\n", Socket);
LibTCPClose(Socket, FALSE); LibTCPClose(Connection, FALSE);
} }
DbgPrint("[IP, TCPClose] Leaving. Connection->RefCount = %d\n", Connection->RefCount); DbgPrint("[IP, TCPClose] Leaving. Connection->RefCount = %d\n", Connection->RefCount);
@ -299,7 +299,7 @@ NTSTATUS TCPConnect
bindaddr.addr = Connection->AddressFile->Address.Address.IPv4Address; bindaddr.addr = Connection->AddressFile->Address.Address.IPv4Address;
} }
Status = TCPTranslateError(LibTCPBind(Connection->SocketContext, Status = TCPTranslateError(LibTCPBind(Connection,
&bindaddr, &bindaddr,
Connection->AddressFile->Port)); Connection->AddressFile->Port));
@ -338,7 +338,7 @@ NTSTATUS TCPConnect
InsertTailList( &Connection->ConnectRequest, &Bucket->Entry ); InsertTailList( &Connection->ConnectRequest, &Bucket->Entry );
Status = TCPTranslateError(LibTCPConnect(Connection->SocketContext, Status = TCPTranslateError(LibTCPConnect(Connection,
&connaddr, &connaddr,
RemotePort)); RemotePort));
@ -372,12 +372,12 @@ NTSTATUS TCPDisconnect
{ {
if (Flags & TDI_DISCONNECT_RELEASE) if (Flags & TDI_DISCONNECT_RELEASE)
{ {
Status = TCPTranslateError(LibTCPShutdown(Connection->SocketContext, 0, 1)); Status = TCPTranslateError(LibTCPShutdown(Connection, 0, 1));
} }
if ((Flags & TDI_DISCONNECT_ABORT) || !Flags) if ((Flags & TDI_DISCONNECT_ABORT) || !Flags)
{ {
Status = TCPTranslateError(LibTCPShutdown(Connection->SocketContext, 1, 1)); Status = TCPTranslateError(LibTCPShutdown(Connection, 1, 1));
} }
} }
else else
@ -464,7 +464,7 @@ NTSTATUS TCPSendData
Connection->SocketContext)); Connection->SocketContext));
DbgPrint("[IP, TCPSendData] Called\n"); DbgPrint("[IP, TCPSendData] Called\n");
Status = TCPTranslateError(LibTCPSend(Connection->SocketContext, Status = TCPTranslateError(LibTCPSend(Connection,
BufferData, BufferData,
SendLength, SendLength,
FALSE)); FALSE));

View file

@ -1,3 +1,8 @@
include_directories(
BEFORE include
${REACTOS_SOURCE_DIR}/drivers/network/tcpip/include
${REACTOS_SOURCE_DIR}/lib/drivers/lwip/src/include
${REACTOS_SOURCE_DIR}/lib/drivers/lwip/src/include/ipv4)
include_directories( include_directories(
src/include src/include

View file

@ -4,6 +4,9 @@
#include "lwip/tcp.h" #include "lwip/tcp.h"
#include "lwip/pbuf.h" #include "lwip/pbuf.h"
#include "lwip/ip_addr.h" #include "lwip/ip_addr.h"
#include "tcpip.h"
typedef struct tcp_pcb* PTCP_PCB;
/* External TCP event handlers */ /* External TCP event handlers */
extern void TCPConnectEventHandler(void *arg, const err_t err); extern void TCPConnectEventHandler(void *arg, const err_t err);
@ -13,16 +16,17 @@ extern void TCPFinEventHandler(void *arg, const err_t err);
extern u32_t TCPRecvEventHandler(void *arg, struct pbuf *p); extern u32_t TCPRecvEventHandler(void *arg, struct pbuf *p);
/* TCP functions */ /* TCP functions */
struct tcp_pcb *LibTCPSocket(void *arg); PTCP_PCB LibTCPSocket(void *arg);
err_t LibTCPBind(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port); err_t LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port);
struct tcp_pcb *LibTCPListen(struct tcp_pcb *pcb, const u8_t backlog); PTCP_PCB LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog);
err_t LibTCPSend(struct tcp_pcb *pcb, void *const dataptr, const u16_t len, const int safe); err_t LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe);
err_t LibTCPConnect(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port); err_t LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port);
err_t LibTCPShutdown(struct tcp_pcb *pcb, const int shut_rx, const int shut_tx); err_t LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx);
err_t LibTCPClose(struct tcp_pcb *pcb, const int safe); err_t LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe);
err_t LibTCPGetPeerName(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, u16_t *const port);
err_t LibTCPGetHostName(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, u16_t *const port); err_t LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port);
void LibTCPAccept(struct tcp_pcb *pcb, struct tcp_pcb *listen_pcb, void *arg); err_t LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port);
void LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg);
/* IP functions */ /* IP functions */
void LibIPInsertPacket(void *ifarg, const void *const data, const u32_t size); void LibIPInsertPacket(void *ifarg, const void *const data, const u32_t size);

View file

@ -272,7 +272,7 @@ struct bind_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
struct ip_addr *IpAddress; struct ip_addr *IpAddress;
u16_t Port; u16_t Port;
@ -288,18 +288,20 @@ LibTCPBindCallback(void *arg)
ASSERT(msg); ASSERT(msg);
msg->Error = tcp_bind(msg->Pcb, msg->IpAddress, ntohs(msg->Port)); msg->Error = tcp_bind((PTCP_PCB)msg->Connection->SocketContext,
msg->IpAddress,
ntohs(msg->Port));
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE); KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
} }
err_t err_t
LibTCPBind(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port) LibTCPBind(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
{ {
struct bind_callback_msg *msg; struct bind_callback_msg *msg;
err_t ret; err_t ret;
if (!pcb) if (!Connection->SocketContext)
return ERR_CLSD; return ERR_CLSD;
DbgPrint("[lwIP, LibTCPBind] Called\n"); DbgPrint("[lwIP, LibTCPBind] Called\n");
@ -308,7 +310,7 @@ LibTCPBind(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port)
if (msg) if (msg)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
msg->IpAddress = ipaddr; msg->IpAddress = ipaddr;
msg->Port = port; msg->Port = port;
@ -319,7 +321,7 @@ LibTCPBind(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port)
else else
ret = ERR_CLSD; ret = ERR_CLSD;
DbgPrint("[lwIP, LibTCPBind] pcb = 0x%x\n", pcb); DbgPrint("[lwIP, LibTCPBind] pcb = 0x%x\n", Connection->SocketContext);
DbgPrint("[lwIP, LibTCPBind] Done\n"); DbgPrint("[lwIP, LibTCPBind] Done\n");
@ -337,11 +339,11 @@ struct listen_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
u8_t Backlog; u8_t Backlog;
/* Output */ /* Output */
struct tcp_pcb *NewPcb; PTCP_PCB NewPcb;
}; };
static static
@ -354,7 +356,7 @@ LibTCPListenCallback(void *arg)
DbgPrint("[lwIP, LibTCPListenCallback] Called\n"); DbgPrint("[lwIP, LibTCPListenCallback] Called\n");
msg->NewPcb = tcp_listen_with_backlog(msg->Pcb, msg->Backlog); msg->NewPcb = tcp_listen_with_backlog((PTCP_PCB)msg->Connection->SocketContext, msg->Backlog);
if (msg->NewPcb) if (msg->NewPcb)
{ {
@ -366,22 +368,22 @@ LibTCPListenCallback(void *arg)
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE); KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
} }
struct tcp_pcb * PTCP_PCB
LibTCPListen(struct tcp_pcb *pcb, const u8_t backlog) LibTCPListen(PCONNECTION_ENDPOINT Connection, const u8_t backlog)
{ {
struct listen_callback_msg *msg; struct listen_callback_msg *msg;
void *ret; PTCP_PCB ret;
DbgPrint("[lwIP, LibTCPListen] Called on pcb = 0x%x\n", pcb); DbgPrint("[lwIP, LibTCPListen] Called on pcb = 0x%x\n", Connection->SocketContext);
if (!pcb) if (!Connection->SocketContext)
return NULL; return NULL;
msg = ExAllocatePool(NonPagedPool, sizeof(struct listen_callback_msg)); msg = ExAllocatePool(NonPagedPool, sizeof(struct listen_callback_msg));
if (msg) if (msg)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
msg->Backlog = backlog; msg->Backlog = backlog;
tcpip_callback_with_block(LibTCPListenCallback, msg, 1); tcpip_callback_with_block(LibTCPListenCallback, msg, 1);
@ -392,7 +394,7 @@ LibTCPListen(struct tcp_pcb *pcb, const u8_t backlog)
ret = NULL; ret = NULL;
DbgPrint("[lwIP, LibTCPListen] pcb = 0x%x, newpcb = 0x%x, sizeof(pcb) = %d \n", DbgPrint("[lwIP, LibTCPListen] pcb = 0x%x, newpcb = 0x%x, sizeof(pcb) = %d \n",
pcb, ret, sizeof(struct tcp_pcb)); Connection->SocketContext, ret, sizeof(struct tcp_pcb));
DbgPrint("[lwIP, LibTCPListen] Done\n"); DbgPrint("[lwIP, LibTCPListen] Done\n");
@ -410,7 +412,7 @@ struct send_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
void *Data; void *Data;
u16_t DataLength; u16_t DataLength;
@ -426,26 +428,29 @@ LibTCPSendCallback(void *arg)
ASSERT(msg); ASSERT(msg);
if (tcp_sndbuf(msg->Pcb) < msg->DataLength) if (tcp_sndbuf((PTCP_PCB)msg->Connection->SocketContext) < msg->DataLength)
{ {
msg->Error = ERR_INPROGRESS; msg->Error = ERR_INPROGRESS;
} }
else else
{ {
msg->Error = tcp_write(msg->Pcb, msg->Data, msg->DataLength, TCP_WRITE_FLAG_COPY); msg->Error = tcp_write((PTCP_PCB)msg->Connection->SocketContext,
msg->Data,
msg->DataLength,
TCP_WRITE_FLAG_COPY);
tcp_output(msg->Pcb); tcp_output((PTCP_PCB)msg->Connection->SocketContext);
} }
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE); KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
} }
err_t err_t
LibTCPSend(struct tcp_pcb *pcb, void *const dataptr, const u16_t len, const int safe) LibTCPSend(PCONNECTION_ENDPOINT Connection, void *const dataptr, const u16_t len, const int safe)
{ {
err_t ret; err_t ret;
if (!pcb) if (!Connection->SocketContext)
return ERR_CLSD; return ERR_CLSD;
/* /*
@ -455,14 +460,14 @@ LibTCPSend(struct tcp_pcb *pcb, void *const dataptr, const u16_t len, const int
*/ */
if (safe) if (safe)
{ {
if (tcp_sndbuf(pcb) < len) if (tcp_sndbuf((PTCP_PCB)Connection->SocketContext) < len)
{ {
ret = ERR_INPROGRESS; ret = ERR_INPROGRESS;
} }
else else
{ {
ret = tcp_write(pcb, dataptr, len, TCP_WRITE_FLAG_COPY); ret = tcp_write((PTCP_PCB)Connection->SocketContext, dataptr, len, TCP_WRITE_FLAG_COPY);
tcp_output(pcb); tcp_output((PTCP_PCB)Connection->SocketContext);
} }
return ret; return ret;
@ -475,7 +480,7 @@ LibTCPSend(struct tcp_pcb *pcb, void *const dataptr, const u16_t len, const int
if (msg) if (msg)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
msg->Data = dataptr; msg->Data = dataptr;
msg->DataLength = len; msg->DataLength = len;
@ -486,7 +491,7 @@ LibTCPSend(struct tcp_pcb *pcb, void *const dataptr, const u16_t len, const int
else else
ret = ERR_CLSD; ret = ERR_CLSD;
DbgPrint("[lwIP, LibTCPSend] pcb = 0x%x\n", pcb); DbgPrint("[lwIP, LibTCPSend] pcb = 0x%x\n", Connection->SocketContext);
ExFreePool(msg); ExFreePool(msg);
@ -503,7 +508,7 @@ struct connect_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
struct ip_addr *IpAddress; struct ip_addr *IpAddress;
u16_t Port; u16_t Port;
@ -521,10 +526,13 @@ LibTCPConnectCallback(void *arg)
ASSERT(arg); ASSERT(arg);
tcp_recv(msg->Pcb, InternalRecvEventHandler); tcp_recv((PTCP_PCB)msg->Connection->SocketContext, InternalRecvEventHandler);
tcp_sent(msg->Pcb, InternalSendEventHandler); tcp_sent((PTCP_PCB)msg->Connection->SocketContext, InternalSendEventHandler);
err_t Error = tcp_connect((PTCP_PCB)msg->Connection->SocketContext,
msg->IpAddress, ntohs(msg->Port),
InternalConnectEventHandler);
err_t Error = tcp_connect(msg->Pcb, msg->IpAddress, ntohs(msg->Port), InternalConnectEventHandler);
msg->Error = Error == ERR_OK ? ERR_INPROGRESS : Error; msg->Error = Error == ERR_OK ? ERR_INPROGRESS : Error;
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE); KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
@ -533,21 +541,21 @@ LibTCPConnectCallback(void *arg)
} }
err_t err_t
LibTCPConnect(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t port) LibTCPConnect(PCONNECTION_ENDPOINT Connection, struct ip_addr *const ipaddr, const u16_t port)
{ {
struct connect_callback_msg *msg; struct connect_callback_msg *msg;
err_t ret; err_t ret;
DbgPrint("[lwIP, LibTCPConnect] Called\n"); DbgPrint("[lwIP, LibTCPConnect] Called\n");
if (!pcb) if (!Connection->SocketContext)
return ERR_CLSD; return ERR_CLSD;
msg = ExAllocatePool(NonPagedPool, sizeof(struct connect_callback_msg)); msg = ExAllocatePool(NonPagedPool, sizeof(struct connect_callback_msg));
if (msg) if (msg)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
msg->IpAddress = ipaddr; msg->IpAddress = ipaddr;
msg->Port = port; msg->Port = port;
@ -562,7 +570,7 @@ LibTCPConnect(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, const u16_t por
ExFreePool(msg); ExFreePool(msg);
DbgPrint("[lwIP, LibTCPConnect] pcb = 0x%x\n", pcb); DbgPrint("[lwIP, LibTCPConnect] pcb = 0x%x\n", Connection->SocketContext);
DbgPrint("[lwIP, LibTCPConnect] Done\n"); DbgPrint("[lwIP, LibTCPConnect] Done\n");
@ -578,7 +586,7 @@ struct shutdown_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
int shut_rx; int shut_rx;
int shut_tx; int shut_tx;
@ -597,9 +605,12 @@ LibTCPShutdownCallback(void *arg)
it means lwIP will take care of it anyway and if it does so before us it will it means lwIP will take care of it anyway and if it does so before us it will
cause memory corruption. cause memory corruption.
*/ */
if ((msg->Pcb->state == ESTABLISHED) || (msg->Pcb->state == SYN_RCVD)) if ((((PTCP_PCB)msg->Connection->SocketContext)->state == ESTABLISHED) ||
(((PTCP_PCB)msg->Connection->SocketContext)->state == SYN_RCVD))
{ {
msg->Error = tcp_shutdown(msg->Pcb, msg->shut_rx, msg->shut_tx); msg->Error =
tcp_shutdown((PTCP_PCB)msg->Connection->SocketContext,
msg->shut_rx, msg->shut_tx);
} }
else else
msg->Error = ERR_OK; msg->Error = ERR_OK;
@ -608,27 +619,29 @@ LibTCPShutdownCallback(void *arg)
} }
err_t err_t
LibTCPShutdown(struct tcp_pcb *pcb, const int shut_rx, const int shut_tx) LibTCPShutdown(PCONNECTION_ENDPOINT Connection, const int shut_rx, const int shut_tx)
{ {
struct shutdown_callback_msg *msg; struct shutdown_callback_msg *msg;
err_t ret; err_t ret;
DbgPrint("[lwIP, LibTCPShutdown] Called on pcb = 0x%x, rx = %d, tx = %d\n", pcb, shut_rx, shut_tx); DbgPrint("[lwIP, LibTCPShutdown] Called on pcb = 0x%x, rx = %d, tx = %d\n",
Connection->SocketContext, shut_rx, shut_tx);
if (!pcb) if (!Connection->SocketContext)
{ {
DbgPrint("[lwIP, LibTCPShutdown] Done... NO pcb\n"); DbgPrint("[lwIP, LibTCPShutdown] Done... NO pcb\n");
return ERR_CLSD; return ERR_CLSD;
} }
DbgPrint("[lwIP, LibTCPShutdown] pcb->state = %s\n", tcp_state_str[pcb->state]); DbgPrint("[lwIP, LibTCPShutdown] pcb->state = %s\n",
tcp_state_str[((PTCP_PCB)Connection->SocketContext)->state]);
msg = ExAllocatePool(NonPagedPool, sizeof(struct shutdown_callback_msg)); msg = ExAllocatePool(NonPagedPool, sizeof(struct shutdown_callback_msg));
if (msg) if (msg)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
msg->shut_rx = shut_rx; msg->shut_rx = shut_rx;
msg->shut_tx = shut_tx; msg->shut_tx = shut_tx;
@ -655,7 +668,7 @@ struct close_callback_msg
KEVENT Event; KEVENT Event;
/* Input */ /* Input */
struct tcp_pcb *Pcb; PCONNECTION_ENDPOINT Connection;
/* Output */ /* Output */
err_t Error; err_t Error;
@ -663,49 +676,9 @@ struct close_callback_msg
static static
void void
LibTCPCloseCallback(void *arg) CloseCallbacks(struct tcp_pcb *pcb)
{ {
struct close_callback_msg *msg = arg;
if (msg->Pcb->state == CLOSED)
{
DbgPrint("[lwIP, LibTCPCloseCallback] Connection was closed on us\n");
msg->Error = ERR_OK;
}
if (msg->Pcb->state == LISTEN)
{
DbgPrint("[lwIP, LibTCPCloseCallback] Closing a listener\n");
msg->Error = tcp_close(msg->Pcb);
}
else
{
DbgPrint("[lwIP, LibTCPCloseCallback] Aborting a connection\n");
tcp_abort(msg->Pcb);
msg->Error = ERR_OK;
}
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
err_t
//LibTCPClose(struct tcp_pcb *pcb, const int safe)
LibTCPClose(struct tcp_pcb *pcb, const int safe)
{
err_t ret;
DbgPrint("[lwIP, LibTCPClose] Called on pcb = 0x%x\n", pcb);
if (!pcb)
{
DbgPrint("[lwIP, LibTCPClose] Done... NO pcb\n");
return ERR_CLSD;
}
DbgPrint("[lwIP, LibTCPClose] pcb->state = %s\n", tcp_state_str[pcb->state]);
tcp_arg(pcb, NULL); tcp_arg(pcb, NULL);
/* /*
if this pcb is not in LISTEN state than it has if this pcb is not in LISTEN state than it has
valid recv, send and err callbacks to cancel valid recv, send and err callbacks to cancel
@ -718,6 +691,58 @@ LibTCPClose(struct tcp_pcb *pcb, const int safe)
} }
tcp_accept(pcb, NULL); tcp_accept(pcb, NULL);
}
static
void
LibTCPCloseCallback(void *arg)
{
struct close_callback_msg *msg = arg;
DbgPrint("[lwIP, LibTCPCloseCallback] pcb = 0x%x\n", (PTCP_PCB)msg->Connection->SocketContext);
if (!msg->Connection->SocketContext)
{
DbgPrint("[lwIP, LibTCPCloseCallback] NULL pcb...bail, bail!!!\n");
ASSERT(FALSE);
msg->Error = ERR_OK;
return;
}
CloseCallbacks((PTCP_PCB)msg->Connection->SocketContext);
if (((PTCP_PCB)msg->Connection->SocketContext)->state == LISTEN)
{
DbgPrint("[lwIP, LibTCPCloseCallback] Closing a listener\n");
msg->Error = tcp_close((PTCP_PCB)msg->Connection->SocketContext);
}
else
{
DbgPrint("[lwIP, LibTCPCloseCallback] Aborting a connection\n");
tcp_abort((PTCP_PCB)msg->Connection->SocketContext);
msg->Error = ERR_OK;
}
KeSetEvent(&msg->Event, IO_NO_INCREMENT, FALSE);
}
err_t
LibTCPClose(PCONNECTION_ENDPOINT Connection, const int safe)
{
err_t ret;
DbgPrint("[lwIP, LibTCPClose] Called on pcb = 0x%x\n", Connection->SocketContext);
if (!Connection->SocketContext)
{
DbgPrint("[lwIP, LibTCPClose] Done... NO pcb\n");
return ERR_CLSD;
}
DbgPrint("[lwIP, LibTCPClose] pcb->state = %s\n",
tcp_state_str[((PTCP_PCB)Connection->SocketContext)->state]);
/* /*
If we're being called from a handler it means we're in the conetxt of teh tcpip If we're being called from a handler it means we're in the conetxt of teh tcpip
@ -726,15 +751,16 @@ LibTCPClose(struct tcp_pcb *pcb, const int safe)
*/ */
if (safe) if (safe)
{ {
if (pcb->state == LISTEN) CloseCallbacks((PTCP_PCB)Connection->SocketContext);
if ( ((PTCP_PCB)Connection->SocketContext)->state == LISTEN )
{ {
DbgPrint("[lwIP, LibTCPClose] Closing a listener\n"); DbgPrint("[lwIP, LibTCPClose] Closing a listener\n");
ret = tcp_close(pcb); ret = tcp_close((PTCP_PCB)Connection->SocketContext);
} }
else else
{ {
DbgPrint("[lwIP, LibTCPClose] Aborting a connection\n"); DbgPrint("[lwIP, LibTCPClose] Aborting a connection\n");
tcp_abort(pcb); tcp_abort((PTCP_PCB)Connection->SocketContext);
ret = ERR_OK; ret = ERR_OK;
} }
@ -749,7 +775,7 @@ LibTCPClose(struct tcp_pcb *pcb, const int safe)
{ {
KeInitializeEvent(&msg->Event, NotificationEvent, FALSE); KeInitializeEvent(&msg->Event, NotificationEvent, FALSE);
msg->Pcb = pcb; msg->Connection = Connection;
tcpip_callback_with_block(LibTCPCloseCallback, msg, 1); tcpip_callback_with_block(LibTCPCloseCallback, msg, 1);
@ -772,7 +798,7 @@ LibTCPClose(struct tcp_pcb *pcb, const int safe)
} }
void void
LibTCPAccept(struct tcp_pcb *pcb, struct tcp_pcb *listen_pcb, void *arg) LibTCPAccept(PTCP_PCB pcb, struct tcp_pcb *listen_pcb, void *arg)
{ {
DbgPrint("[lwIP, LibTCPAccept] Called. (pcb, arg) = (0x%x, 0x%x)\n", pcb, arg); DbgPrint("[lwIP, LibTCPAccept] Called. (pcb, arg) = (0x%x, 0x%x)\n", pcb, arg);
@ -790,7 +816,7 @@ LibTCPAccept(struct tcp_pcb *pcb, struct tcp_pcb *listen_pcb, void *arg)
} }
err_t err_t
LibTCPGetHostName(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, u16_t *const port) LibTCPGetHostName(PTCP_PCB pcb, struct ip_addr *const ipaddr, u16_t *const port)
{ {
DbgPrint("[lwIP, LibTCPGetHostName] Called. pcb = (0x%x)\n", pcb); DbgPrint("[lwIP, LibTCPGetHostName] Called. pcb = (0x%x)\n", pcb);
@ -808,7 +834,7 @@ LibTCPGetHostName(struct tcp_pcb *pcb, struct ip_addr *const ipaddr, u16_t *cons
} }
err_t err_t
LibTCPGetPeerName(struct tcp_pcb *pcb, struct ip_addr * const ipaddr, u16_t * const port) LibTCPGetPeerName(PTCP_PCB pcb, struct ip_addr * const ipaddr, u16_t * const port)
{ {
DbgPrint("[lwIP, LibTCPGetPeerName] pcb = (0x%x)\n", pcb); DbgPrint("[lwIP, LibTCPGetPeerName] pcb = (0x%x)\n", pcb);