- Fix broken handling of partial receives

svn path=/trunk/; revision=57169
This commit is contained in:
Cameron Gutman 2012-08-26 22:24:49 +00:00
parent 44356bb773
commit aa6611286c
2 changed files with 50 additions and 31 deletions

View file

@ -15,6 +15,7 @@ typedef struct tcp_pcb* PTCP_PCB;
typedef struct _QUEUE_ENTRY
{
struct pbuf *p;
ULONG Offset;
LIST_ENTRY ListEntry;
} QUEUE_ENTRY, *PQUEUE_ENTRY;

View file

@ -60,6 +60,7 @@ void LibTCPEnqueuePacket(PCONNECTION_ENDPOINT Connection, struct pbuf *p)
qp = (PQUEUE_ENTRY)ExAllocateFromNPagedLookasideList(&QueueEntryLookasideList);
qp->p = p;
qp->Offset = 0;
ExInterlockedInsertTailList(&Connection->PacketQueue, &qp->ListEntry, &Connection->Lock);
}
@ -82,9 +83,10 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
{
PQUEUE_ENTRY qp;
struct pbuf* p;
NTSTATUS Status = STATUS_PENDING;
UINT ReadLength, ExistingDataLength;
NTSTATUS Status;
UINT ReadLength, PayloadLength;
KIRQL OldIrql;
PUCHAR Payload;
(*Received) = 0;
@ -95,50 +97,54 @@ NTSTATUS LibTCPGetDataFromConnectionQueue(PCONNECTION_ENDPOINT Connection, PUCHA
while ((qp = LibTCPDequeuePacket(Connection)) != NULL)
{
p = qp->p;
ExistingDataLength = (*Received);
Status = STATUS_SUCCESS;
/* Calculate the payload first */
Payload = p->payload;
Payload += qp->Offset;
PayloadLength = p->len;
PayloadLength -= qp->Offset;
ReadLength = MIN(p->tot_len, RecvLen);
if (ReadLength != p->tot_len)
/* Check if we're reading the whole buffer */
ReadLength = MIN(PayloadLength, RecvLen);
if (ReadLength != PayloadLength)
{
if (ExistingDataLength)
{
/* The packet was too big but we used some data already so give it another shot later */
/* Save this one for later */
qp->Offset += ReadLength;
InsertHeadList(&Connection->PacketQueue, &qp->ListEntry);
break;
}
else
{
/* The packet is just too big to fit fully in our buffer, even when empty so
* return an informative status but still copy all the data we can fit.
*/
Status = STATUS_BUFFER_OVERFLOW;
}
qp = NULL;
}
UnlockObject(Connection, OldIrql);
/* Return to a lower IRQL because the receive buffer may be pageable memory */
for (; (*Received) < ReadLength + ExistingDataLength; (*Received) += p->len, p = p->next)
{
RtlCopyMemory(RecvBuffer + (*Received), p->payload, p->len);
}
RtlCopyMemory(RecvBuffer,
Payload,
ReadLength);
LockObject(Connection, &OldIrql);
/* Update trackers */
RecvLen -= ReadLength;
RecvBuffer += ReadLength;
(*Received) += ReadLength;
if (qp != NULL)
{
/* Use this special pbuf free callback function because we're outside tcpip thread */
pbuf_free_callback(qp->p);
ExFreeToNPagedLookasideList(&QueueEntryLookasideList, qp);
}
else
{
/* If we get here, it means we've filled the buffer */
ASSERT(RecvLen == 0);
}
Status = STATUS_SUCCESS;
if (!RecvLen)
break;
if (Status != STATUS_SUCCESS)
break;
}
}
else
@ -196,6 +202,8 @@ err_t
InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t err)
{
PCONNECTION_ENDPOINT Connection = arg;
struct pbuf *pb;
ULONG RecvLen;
/* Make sure the socket didn't get closed */
if (!arg)
@ -208,9 +216,19 @@ InternalRecvEventHandler(void *arg, PTCP_PCB pcb, struct pbuf *p, const err_t er
if (p)
{
LibTCPEnqueuePacket(Connection, p);
pb = p;
RecvLen = 0;
while (pb != NULL)
{
/* Enqueue this buffer */
LibTCPEnqueuePacket(Connection, pb);
RecvLen += pb->len;
tcp_recved(pcb, p->tot_len);
/* Advance and unchain the buffer */
pb = pbuf_dechain(pb);;
}
tcp_recved(pcb, RecvLen);
TCPRecvEventHandler(arg);
}