apply changes r52501 for lwIP branch

- Fix binding to an unspecified port on a connect so that it works reliably by asking the TCP library for a free port instead of assuming that one we have is free
- Fix binding to an unspecified port on a listen which previously would result in the address file not having information stored about the port number assigned
- Fix a nasty bug which resulted in us binding to an arbitrary port during a connect even when the client wanted a specific port

svn path=/branches/GSoC_2011/TcpIpDriver/; revision=52517
This commit is contained in:
Claudiu Mihail 2011-07-03 13:34:36 +00:00
parent 75014f14eb
commit 1b907a202d
6 changed files with 190 additions and 129 deletions

View file

@ -739,7 +739,8 @@ NTSTATUS DispTdiQueryInformation(
Parameters = (PTDI_REQUEST_KERNEL_QUERY_INFORMATION)&IrpSp->Parameters; Parameters = (PTDI_REQUEST_KERNEL_QUERY_INFORMATION)&IrpSp->Parameters;
TranContext = IrpSp->FileObject->FsContext; TranContext = IrpSp->FileObject->FsContext;
if (!TranContext) { if (!TranContext)
{
TI_DbgPrint(MID_TRACE, ("Bad transport context.\n")); TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
@ -753,10 +754,10 @@ NTSTATUS DispTdiQueryInformation(
PTA_IP_ADDRESS Address; PTA_IP_ADDRESS Address;
PCONNECTION_ENDPOINT Endpoint = NULL; PCONNECTION_ENDPOINT Endpoint = NULL;
if (MmGetMdlByteCount(Irp->MdlAddress) < if (MmGetMdlByteCount(Irp->MdlAddress) <
(FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) + (FIELD_OFFSET(TDI_ADDRESS_INFO, Address.Address[0].Address) +
sizeof(TDI_ADDRESS_IP))) { sizeof(TDI_ADDRESS_IP)))
{
TI_DbgPrint(MID_TRACE, ("MDL buffer too small.\n")); TI_DbgPrint(MID_TRACE, ("MDL buffer too small.\n"));
return STATUS_BUFFER_TOO_SMALL; return STATUS_BUFFER_TOO_SMALL;
} }
@ -764,7 +765,8 @@ NTSTATUS DispTdiQueryInformation(
AddressInfo = (PTDI_ADDRESS_INFO)MmGetSystemAddressForMdl(Irp->MdlAddress); AddressInfo = (PTDI_ADDRESS_INFO)MmGetSystemAddressForMdl(Irp->MdlAddress);
Address = (PTA_IP_ADDRESS)&AddressInfo->Address; Address = (PTA_IP_ADDRESS)&AddressInfo->Address;
switch ((ULONG_PTR)IrpSp->FileObject->FsContext2) { switch ((ULONG_PTR)IrpSp->FileObject->FsContext2)
{
case TDI_TRANSPORT_ADDRESS_FILE: case TDI_TRANSPORT_ADDRESS_FILE:
AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle; AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle;
@ -779,12 +781,19 @@ NTSTATUS DispTdiQueryInformation(
return STATUS_SUCCESS; return STATUS_SUCCESS;
case TDI_CONNECTION_FILE: case TDI_CONNECTION_FILE:
Endpoint = Endpoint = (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
(PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
Address->TAAddressCount = 1;
Address->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP;
Address->Address[0].AddressType = TDI_ADDRESS_TYPE_IP;
Address->Address[0].Address[0].sin_port = Endpoint->AddressFile->Port;
Address->Address[0].Address[0].in_addr = Endpoint->AddressFile->Address.Address.IPv4Address;
RtlZeroMemory( RtlZeroMemory(
&Address->Address[0].Address[0].sin_zero, &Address->Address[0].Address[0].sin_zero,
sizeof(Address->Address[0].Address[0].sin_zero)); sizeof(Address->Address[0].Address[0].sin_zero));
return TCPGetSockAddress( Endpoint, (PTRANSPORT_ADDRESS)Address, FALSE );
return STATUS_SUCCESS;
default: default:
TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n")); TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n"));
@ -808,15 +817,15 @@ NTSTATUS DispTdiQueryInformation(
AddressInfo = (PTDI_CONNECTION_INFORMATION) AddressInfo = (PTDI_CONNECTION_INFORMATION)
MmGetSystemAddressForMdl(Irp->MdlAddress); MmGetSystemAddressForMdl(Irp->MdlAddress);
switch ((ULONG_PTR)IrpSp->FileObject->FsContext2) { switch ((ULONG_PTR)IrpSp->FileObject->FsContext2)
{
case TDI_TRANSPORT_ADDRESS_FILE: case TDI_TRANSPORT_ADDRESS_FILE:
AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle; AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle;
Endpoint = AddrFile ? AddrFile->Connection : NULL; Endpoint = AddrFile ? AddrFile->Connection : NULL;
break; break;
case TDI_CONNECTION_FILE: case TDI_CONNECTION_FILE:
Endpoint = Endpoint = (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
(PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
break; break;
default: default:
@ -824,7 +833,8 @@ NTSTATUS DispTdiQueryInformation(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
if (!Endpoint) { if (!Endpoint)
{
TI_DbgPrint(MID_TRACE, ("No connection object.\n")); TI_DbgPrint(MID_TRACE, ("No connection object.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }

View file

@ -199,7 +199,10 @@ VOID AddrFileFree(
switch (AddrFile->Protocol) switch (AddrFile->Protocol)
{ {
case IPPROTO_TCP: case IPPROTO_TCP:
TCPFreePort( AddrFile->Port ); if (AddrFile->Port)
{
TCPFreePort(AddrFile->Port);
}
break; break;
case IPPROTO_UDP: case IPPROTO_UDP:
@ -289,17 +292,27 @@ NTSTATUS FileOpenAddress(
switch (Protocol) switch (Protocol)
{ {
case IPPROTO_TCP: case IPPROTO_TCP:
AddrFile->Port = if (Address->Address[0].Address[0].sin_port)
TCPAllocatePort(Address->Address[0].Address[0].sin_port); {
/* The client specified an explicit port so we force a bind to this */
AddrFile->Port = TCPAllocatePort(Address->Address[0].Address[0].sin_port);
if ((Address->Address[0].Address[0].sin_port && /* Check for bind success */
AddrFile->Port != Address->Address[0].Address[0].sin_port) || if (AddrFile->Port == 0xffff)
AddrFile->Port == 0xffff)
{ {
ExFreePoolWithTag(AddrFile, ADDR_FILE_TAG); ExFreePoolWithTag(AddrFile, ADDR_FILE_TAG);
return STATUS_ADDRESS_ALREADY_EXISTS; return STATUS_ADDRESS_ALREADY_EXISTS;
} }
/* Sanity check */
ASSERT(Address->Address[0].Address[0].sin_port == AddrFile->Port);
}
else
{
/* The client wants an unspecified port so we wait to see what the TCP library gives us */
AddrFile->Port = 0;
}
AddEntity(CO_TL_ENTITY, AddrFile, CO_TL_TCP); AddEntity(CO_TL_ENTITY, AddrFile, CO_TL_TCP);
AddrFile->Send = NULL; /* TCPSendData */ AddrFile->Send = NULL; /* TCPSendData */

View file

@ -51,12 +51,14 @@ NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
NTSTATUS Status = STATUS_SUCCESS; NTSTATUS Status = STATUS_SUCCESS;
struct ip_addr AddressToBind; struct ip_addr AddressToBind;
KIRQL OldIrql; KIRQL OldIrql;
TA_IP_ADDRESS LocalAddress;
ASSERT(Connection); ASSERT(Connection);
ASSERT_KM_POINTER(Connection->AddressFile);
LockObject(Connection, &OldIrql); LockObject(Connection, &OldIrql);
ASSERT_KM_POINTER(Connection->AddressFile);
TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Called\n")); TI_DbgPrint(DEBUG_TCP,("[IP, TCPListen] Called\n"));
DbgPrint("[IP, TCPListen] Called\n"); DbgPrint("[IP, TCPListen] Called\n");
@ -69,6 +71,24 @@ NTSTATUS TCPListen(PCONNECTION_ENDPOINT Connection, UINT Backlog)
&AddressToBind, &AddressToBind,
Connection->AddressFile->Port)); Connection->AddressFile->Port));
if (NT_SUCCESS(Status))
{
/* Check if we had an unspecified port */
if (!Connection->AddressFile->Port)
{
/* We did, so we need to copy back the port */
Status = TCPGetSockAddress(Connection, (PTRANSPORT_ADDRESS)&LocalAddress, FALSE);
if (NT_SUCCESS(Status))
{
/* Allocate the port in the port bitmap */
Connection->AddressFile->Port = TCPAllocatePort(LocalAddress.Address[0].Address[0].sin_port);
/* This should never fail */
ASSERT(Connection->AddressFile->Port != 0xFFFF);
}
}
}
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
{ {
Connection->SocketContext = LibTCPListen(Connection->SocketContext, Backlog); Connection->SocketContext = LibTCPListen(Connection->SocketContext, Backlog);

View file

@ -212,7 +212,7 @@ TCPAcceptEventHandler(void *arg, struct tcp_pcb *newpcb)
/* sanity assert...this should never be in anything else but a CLOSED state */ /* sanity assert...this should never be in anything else but a CLOSED state */
ASSERT(((struct tcp_pcb*)OldSocketContext)->state == CLOSED); ASSERT(((struct tcp_pcb*)OldSocketContext)->state == CLOSED);
/* free socket context created in FileOpenConnection, as we're using a new /* free socket context created in FileOpenConnection, as we're using a new
one; we free it asynchornously because otherwise we create a dedlock */ one; we free it asynchornously because otherwise we create a deadlock */
ChewCreate(SocketContextCloseWorker, OldSocketContext); ChewCreate(SocketContextCloseWorker, OldSocketContext);
} }

View file

@ -39,17 +39,22 @@ TCPSendDataCallback(struct netif *netif, struct pbuf *p, struct ip_addr *dest)
} }
else else
{ {
DbgPrint("[IP, TCPSendDataCallback] FAIL EINVAL 1\n");
return EINVAL; return EINVAL;
} }
DbgPrint("[IP, TCPSendDataCallback] Set packet local and remore adresses\n");
if (!(NCE = RouteGetRouteToDestination(&RemoteAddress))) if (!(NCE = RouteGetRouteToDestination(&RemoteAddress)))
{ {
DbgPrint("[IP, TCPSendDataCallback] FAIL EINVAL 2\n");
return EINVAL; return EINVAL;
} }
NdisStatus = AllocatePacketWithBuffer(&Packet.NdisPacket, NULL, p->tot_len); NdisStatus = AllocatePacketWithBuffer(&Packet.NdisPacket, NULL, p->tot_len);
if (NdisStatus != NDIS_STATUS_SUCCESS) if (NdisStatus != NDIS_STATUS_SUCCESS)
{ {
DbgPrint("[IP, TCPSendDataCallback] FAIL ENOBUFS\n");
return ENOBUFS; return ENOBUFS;
} }
@ -61,6 +66,8 @@ TCPSendDataCallback(struct netif *netif, struct pbuf *p, struct ip_addr *dest)
RtlCopyMemory(((PUCHAR)Packet.Header) + i, p1->payload, p1->len); RtlCopyMemory(((PUCHAR)Packet.Header) + i, p1->payload, p1->len);
} }
DbgPrint("[IP, TCPSendDataCallback] Allocated NDIS packet and set data\n");
Packet.HeaderSize = sizeof(IPv4_HEADER); Packet.HeaderSize = sizeof(IPv4_HEADER);
Packet.TotalSize = p->tot_len; Packet.TotalSize = p->tot_len;
Packet.SrcAddr = LocalAddress; Packet.SrcAddr = LocalAddress;
@ -68,10 +75,13 @@ TCPSendDataCallback(struct netif *netif, struct pbuf *p, struct ip_addr *dest)
if (!NT_SUCCESS(IPSendDatagram(&Packet, NCE, TCPPacketSendComplete, NULL))) if (!NT_SUCCESS(IPSendDatagram(&Packet, NCE, TCPPacketSendComplete, NULL)))
{ {
DbgPrint("[IP, TCPSendDataCallback] FAIL EINVAL 3\n");
FreeNdisPacket(Packet.NdisPacket); FreeNdisPacket(Packet.NdisPacket);
return EINVAL; return EINVAL;
} }
DbgPrint("[IP, TCPSendDataCallback] Leaving\n");
return 0; return 0;
} }

View file

@ -110,11 +110,12 @@ NTSTATUS TCPClose
PVOID Socket; PVOID Socket;
LockObject(Connection, &OldIrql); LockObject(Connection, &OldIrql);
DbgPrint("[IP, TCPClose] Called for Connection( 0x%x )->SocketConext( 0x%x )\n", Connection, Connection->SocketContext);
Socket = Connection->SocketContext; Socket = Connection->SocketContext;
Connection->SocketContext = NULL; Connection->SocketContext = NULL;
DbgPrint("[IP, TCPClose] Called\n");
/* We should not be associated to an address file at this point */ /* We should not be associated to an address file at this point */
ASSERT(!Connection->AddressFile); ASSERT(!Connection->AddressFile);
@ -122,6 +123,9 @@ NTSTATUS TCPClose
if (Socket) if (Socket)
{ {
FlushAllQueues(Connection, STATUS_CANCELLED); FlushAllQueues(Connection, STATUS_CANCELLED);
DbgPrint("[IP, TCPClose] Socket (pcb) = 0x%x\n", Socket);
LibTCPClose(Socket); LibTCPClose(Socket);
} }
@ -247,6 +251,7 @@ NTSTATUS TCPConnect
struct ip_addr bindaddr, connaddr; struct ip_addr bindaddr, connaddr;
IP_ADDRESS RemoteAddress; IP_ADDRESS RemoteAddress;
USHORT RemotePort; USHORT RemotePort;
TA_IP_ADDRESS LocalAddress;
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
PNEIGHBOR_CACHE_ENTRY NCE; PNEIGHBOR_CACHE_ENTRY NCE;
KIRQL OldIrql; KIRQL OldIrql;
@ -299,6 +304,23 @@ NTSTATUS TCPConnect
DbgPrint("LibTCPBind: 0x%x\n", Status); DbgPrint("LibTCPBind: 0x%x\n", Status);
if (NT_SUCCESS(Status))
{
/* Check if we had an unspecified port */
if (!Connection->AddressFile->Port)
{
/* We did, so we need to copy back the port */
Status = TCPGetSockAddress(Connection, (PTRANSPORT_ADDRESS)&LocalAddress, FALSE);
if (NT_SUCCESS(Status))
{
/* Allocate the port in the port bitmap */
Connection->AddressFile->Port = TCPAllocatePort(LocalAddress.Address[0].Address[0].sin_port);
/* This should never fail */
ASSERT(Connection->AddressFile->Port != 0xFFFF);
}
}
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
{ {
connaddr.addr = RemoteAddress.Address.IPv4Address; connaddr.addr = RemoteAddress.Address.IPv4Address;
@ -320,20 +342,6 @@ NTSTATUS TCPConnect
RemotePort)); RemotePort));
DbgPrint("LibTCPConnect: 0x%x\n", Status); DbgPrint("LibTCPConnect: 0x%x\n", Status);
if (Status == STATUS_PENDING)
{
/*Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG );
if( !Bucket )
{
UnlockObject(Connection, OldIrql);
return STATUS_NO_MEMORY;
}
Bucket->Request.RequestNotifyObject = (PVOID)Complete;
Bucket->Request.RequestContext = Context;
InsertTailList( &Connection->ConnectRequest, &Bucket->Entry );*/
} }
} }