- Rewrite a large portion of the send code to have proper support for overlapped and non-blocking sockets

svn path=/trunk/; revision=57187
This commit is contained in:
Cameron Gutman 2012-08-28 00:14:08 +00:00
parent f65fc45ba2
commit 083efd4ac3

View file

@ -21,6 +21,8 @@ static NTSTATUS NTAPI SendComplete
PAFD_SEND_INFO SendReq = NULL;
PAFD_MAPBUF Map;
UINT TotalBytesCopied = 0, TotalBytesProcessed = 0, SpaceAvail, i;
UINT SendLength, BytesCopied;
BOOLEAN HaltSendQueue;
/*
* The Irp parameter passed in is the IRP of the stream between AFD and
@ -94,27 +96,48 @@ static NTSTATUS NTAPI SendComplete
}
RtlMoveMemory( FCB->Send.Window,
FCB->Send.Window + FCB->Send.BytesUsed,
FCB->Send.Window + Irp->IoStatus.Information,
FCB->Send.BytesUsed - Irp->IoStatus.Information );
TotalBytesProcessed = 0;
while (!IsListEmpty(&FCB->PendingIrpList[FUNCTION_SEND]) &&
TotalBytesProcessed != Irp->IoStatus.Information) {
SendLength = Irp->IoStatus.Information;
HaltSendQueue = FALSE;
while (!IsListEmpty(&FCB->PendingIrpList[FUNCTION_SEND]) && SendLength > 0) {
NextIrpEntry = RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]);
NextIrp = CONTAINING_RECORD(NextIrpEntry, IRP, Tail.Overlay.ListEntry);
NextIrpSp = IoGetCurrentIrpStackLocation( NextIrp );
SendReq = GetLockedData(NextIrp, NextIrpSp);
Map = (PAFD_MAPBUF)(SendReq->BufferArray + SendReq->BufferCount);
TotalBytesCopied = 0;
TotalBytesCopied = (ULONG_PTR)NextIrp->Tail.Overlay.DriverContext[3];
ASSERT(TotalBytesCopied != 0);
for( i = 0; i < SendReq->BufferCount; i++ )
TotalBytesCopied += SendReq->BufferArray[i].len;
/* If we didn't get enough, keep waiting */
if (TotalBytesCopied > SendLength)
{
/* Update the bytes left to copy */
TotalBytesCopied -= SendLength;
NextIrp->Tail.Overlay.DriverContext[3] = (PVOID)TotalBytesCopied;
/* Update the state variables */
FCB->Send.BytesUsed -= SendLength;
TotalBytesProcessed += SendLength;
SendLength = 0;
/* Pend the IRP */
InsertHeadList(&FCB->PendingIrpList[FUNCTION_SEND],
&NextIrp->Tail.Overlay.ListEntry);
HaltSendQueue = TRUE;
break;
}
ASSERT(NextIrp->IoStatus.Information != 0);
NextIrp->IoStatus.Status = Irp->IoStatus.Status;
NextIrp->IoStatus.Information = TotalBytesCopied;
FCB->Send.BytesUsed -= TotalBytesCopied;
TotalBytesProcessed += TotalBytesCopied;
SendLength -= TotalBytesCopied;
(void)IoSetCancelRoutine(NextIrp, NULL);
@ -127,11 +150,9 @@ static NTSTATUS NTAPI SendComplete
IoCompleteRequest(NextIrp, IO_NETWORK_INCREMENT);
}
ASSERT(TotalBytesProcessed == Irp->IoStatus.Information);
ASSERT(SendLength == 0);
FCB->Send.BytesUsed -= TotalBytesProcessed;
while( !IsListEmpty( &FCB->PendingIrpList[FUNCTION_SEND] ) ) {
while( !HaltSendQueue && !IsListEmpty( &FCB->PendingIrpList[FUNCTION_SEND] ) ) {
NextIrpEntry = RemoveHeadList(&FCB->PendingIrpList[FUNCTION_SEND]);
NextIrp = CONTAINING_RECORD(NextIrpEntry, IRP, Tail.Overlay.ListEntry);
NextIrpSp = IoGetCurrentIrpStackLocation( NextIrp );
@ -143,37 +164,62 @@ static NTSTATUS NTAPI SendComplete
SpaceAvail = FCB->Send.Size - FCB->Send.BytesUsed;
TotalBytesCopied = 0;
/* Count the total transfer size */
SendLength = 0;
for (i = 0; i < SendReq->BufferCount; i++)
{
SendLength += SendReq->BufferArray[i].len;
}
/* Make sure we've got the space */
if (SendLength > SpaceAvail)
{
/* Blocking sockets have to wait here */
if (SendLength <= FCB->Send.Size && !((SendReq->AfdFlags & AFD_IMMEDIATE) || (FCB->NonBlocking)))
{
FCB->PollState &= ~AFD_EVENT_SEND;
InsertHeadList(&FCB->PendingIrpList[FUNCTION_SEND],
&NextIrp->Tail.Overlay.ListEntry);
NextIrp = NULL;
}
/* Check if we can send anything */
if (SpaceAvail == 0)
{
FCB->PollState &= ~AFD_EVENT_SEND;
/* We should never be non-overlapped and get to this point */
ASSERT(SendReq->AfdFlags & AFD_OVERLAPPED);
InsertHeadList(&FCB->PendingIrpList[FUNCTION_SEND],
&NextIrp->Tail.Overlay.ListEntry);
NextIrp = NULL;
}
}
if (NextIrp == NULL)
break;
for( i = 0; i < SendReq->BufferCount; i++ ) {
if (SpaceAvail < SendReq->BufferArray[i].len)
{
InsertHeadList(&FCB->PendingIrpList[FUNCTION_SEND],
&NextIrp->Tail.Overlay.ListEntry);
NextIrp = NULL;
break;
}
BytesCopied = MIN(SendReq->BufferArray[i].len, SpaceAvail);
Map[i].BufferAddress =
MmMapLockedPages( Map[i].Mdl, KernelMode );
RtlCopyMemory( FCB->Send.Window + FCB->Send.BytesUsed,
Map[i].BufferAddress,
SendReq->BufferArray[i].len );
BytesCopied );
MmUnmapLockedPages( Map[i].BufferAddress, Map[i].Mdl );
TotalBytesCopied += SendReq->BufferArray[i].len;
SpaceAvail -= SendReq->BufferArray[i].len;
TotalBytesCopied += BytesCopied;
SpaceAvail -= BytesCopied;
FCB->Send.BytesUsed += BytesCopied;
}
if (NextIrp != NULL)
{
FCB->Send.BytesUsed += TotalBytesCopied;
}
else
break;
}
if (FCB->Send.Size - FCB->Send.BytesUsed != 0 &&
!FCB->SendClosed)
if (FCB->Send.Size - FCB->Send.BytesUsed != 0 && !FCB->SendClosed &&
IsListEmpty(&FCB->PendingIrpList[FUNCTION_SEND]))
{
FCB->PollState |= AFD_EVENT_SEND;
FCB->PollStatus[FD_WRITE_BIT] = STATUS_SUCCESS;
@ -184,6 +230,7 @@ static NTSTATUS NTAPI SendComplete
FCB->PollState &= ~AFD_EVENT_SEND;
}
/* Some data is still waiting */
if( FCB->Send.BytesUsed )
{
@ -279,13 +326,14 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
PFILE_OBJECT FileObject = IrpSp->FileObject;
PAFD_FCB FCB = FileObject->FsContext;
PAFD_SEND_INFO SendReq;
UINT TotalBytesCopied = 0, i, SpaceAvail = 0;
BOOLEAN NoSpace = FALSE;
UINT TotalBytesCopied = 0, i, SpaceAvail = 0, BytesCopied, SendLength;
AFD_DbgPrint(MID_TRACE,("Called on %x\n", FCB));
if( !SocketAcquireStateLock( FCB ) ) return LostSocket( Irp );
FCB->EventSelectDisabled &= ~AFD_EVENT_SEND;
if( FCB->Flags & AFD_ENDPOINT_CONNECTIONLESS )
{
PAFD_SEND_INFO_UDP SendReq;
@ -316,7 +364,6 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
Status = TdiBuildConnectionInfo( &TargetAddress, FCB->RemoteAddress );
if( NT_SUCCESS(Status) ) {
FCB->EventSelectDisabled &= ~AFD_EVENT_SEND;
FCB->PollState &= ~AFD_EVENT_SEND;
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);
@ -371,7 +418,7 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
if( !(SendReq = LockRequest( Irp, IrpSp, FALSE )) )
return UnlockAndMaybeComplete
( FCB, STATUS_NO_MEMORY, Irp, TotalBytesCopied );
( FCB, STATUS_NO_MEMORY, Irp, 0 );
SendReq->BufferArray = LockBuffers( SendReq->BufferArray,
SendReq->BufferCount,
@ -405,37 +452,62 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
AFD_DbgPrint(MID_TRACE,("We can accept %d bytes\n",
SpaceAvail));
for( i = 0; FCB->Send.BytesUsed < FCB->Send.Size &&
i < SendReq->BufferCount; i++ ) {
/* Count the total transfer size */
SendLength = 0;
for (i = 0; i < SendReq->BufferCount; i++)
{
SendLength += SendReq->BufferArray[i].len;
}
if (SpaceAvail < SendReq->BufferArray[i].len)
/* Make sure we've got the space */
if (SendLength > SpaceAvail)
{
/* Blocking sockets have to wait here */
if (SendLength <= FCB->Send.Size && !((SendReq->AfdFlags & AFD_IMMEDIATE) || (FCB->NonBlocking)))
{
if (TotalBytesCopied + SendReq->BufferArray[i].len > FCB->Send.Size)
{
UnlockBuffers(SendReq->BufferArray, SendReq->BufferCount, FALSE);
return UnlockAndMaybeComplete(FCB, STATUS_BUFFER_OVERFLOW, Irp, 0);
}
SpaceAvail += TotalBytesCopied;
NoSpace = TRUE;
break;
FCB->PollState &= ~AFD_EVENT_SEND;
return LeaveIrpUntilLater(FCB, Irp, FUNCTION_SEND);
}
/* Check if we can send anything */
if (SpaceAvail == 0)
{
FCB->PollState &= ~AFD_EVENT_SEND;
/* Non-overlapped sockets will fail if we can send nothing */
if (!(SendReq->AfdFlags & AFD_OVERLAPPED))
{
UnlockBuffers( SendReq->BufferArray, SendReq->BufferCount, FALSE );
return UnlockAndMaybeComplete( FCB, STATUS_CANT_WAIT, Irp, 0 );
}
else
{
/* Overlapped sockets just pend */
return LeaveIrpUntilLater(FCB, Irp, FUNCTION_SEND);
}
}
}
for ( i = 0; SpaceAvail > 0 && i < SendReq->BufferCount; i++ )
{
BytesCopied = MIN(SendReq->BufferArray[i].len, SpaceAvail);
AFD_DbgPrint(MID_TRACE,("Copying Buffer %d, %x:%d to %x\n",
i,
SendReq->BufferArray[i].buf,
SendReq->BufferArray[i].len,
BytesCopied,
FCB->Send.Window + FCB->Send.BytesUsed));
RtlCopyMemory( FCB->Send.Window + FCB->Send.BytesUsed,
RtlCopyMemory(FCB->Send.Window + FCB->Send.BytesUsed,
SendReq->BufferArray[i].buf,
SendReq->BufferArray[i].len );
BytesCopied);
TotalBytesCopied += SendReq->BufferArray[i].len;
SpaceAvail -= SendReq->BufferArray[i].len;
TotalBytesCopied += BytesCopied;
SpaceAvail -= BytesCopied;
FCB->Send.BytesUsed += BytesCopied;
}
FCB->EventSelectDisabled &= ~AFD_EVENT_SEND;
Irp->IoStatus.Information = TotalBytesCopied;
if( TotalBytesCopied == 0 ) {
AFD_DbgPrint(MID_TRACE,("Empty send\n"));
@ -455,40 +527,25 @@ AfdConnectedSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
FCB->PollState &= ~AFD_EVENT_SEND;
}
if (!NoSpace)
{
FCB->Send.BytesUsed += TotalBytesCopied;
AFD_DbgPrint(MID_TRACE,("Copied %d bytes\n", TotalBytesCopied));
/* We use the IRP tail for some temporary storage here */
Irp->Tail.Overlay.DriverContext[3] = (PVOID)Irp->IoStatus.Information;
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);
if (Status == STATUS_PENDING && !FCB->SendIrp.InFlightRequest)
{
TdiSend(&FCB->SendIrp.InFlightRequest,
FCB->Connection.Object,
0,
FCB->Send.Window,
FCB->Send.BytesUsed,
&FCB->SendIrp.Iosb,
SendComplete,
FCB);
}
SocketStateUnlock(FCB);
return STATUS_PENDING;
}
else
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);
if (Status == STATUS_PENDING && !FCB->SendIrp.InFlightRequest)
{
FCB->PollState &= ~AFD_EVENT_SEND;
if (!(SendReq->AfdFlags & AFD_OVERLAPPED) &&
((SendReq->AfdFlags & AFD_IMMEDIATE) || (FCB->NonBlocking))) {
AFD_DbgPrint(MID_TRACE,("Nonblocking\n"));
UnlockBuffers( SendReq->BufferArray, SendReq->BufferCount, FALSE );
return UnlockAndMaybeComplete( FCB, STATUS_CANT_WAIT, Irp, 0 );
} else {
AFD_DbgPrint(MID_TRACE,("Queuing request\n"));
return LeaveIrpUntilLater( FCB, Irp, FUNCTION_SEND );
}
TdiSend(&FCB->SendIrp.InFlightRequest,
FCB->Connection.Object,
0,
FCB->Send.Window,
FCB->Send.BytesUsed,
&FCB->SendIrp.Iosb,
SendComplete,
FCB);
}
SocketStateUnlock(FCB);
return STATUS_PENDING;
}
NTSTATUS NTAPI
@ -504,6 +561,8 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
if( !SocketAcquireStateLock( FCB ) ) return LostSocket( Irp );
FCB->EventSelectDisabled &= ~AFD_EVENT_SEND;
/* Check that the socket is bound */
if( FCB->State != SOCKET_STATE_BOUND &&
FCB->State != SOCKET_STATE_CREATED)
@ -562,7 +621,6 @@ AfdPacketSocketWriteData(PDEVICE_OBJECT DeviceObject, PIRP Irp,
/* Check the size of the Address given ... */
if( NT_SUCCESS(Status) ) {
FCB->EventSelectDisabled &= ~AFD_EVENT_SEND;
FCB->PollState &= ~AFD_EVENT_SEND;
Status = QueueUserModeIrp(FCB, Irp, FUNCTION_SEND);