- Disassociate the socket context before closing so we don't get signalled with a freed context (major cause of crashes)

- Signal the connection directly from TCPSocketState
 - Remove some unused code
 - Queue completion requests on a separate list so we don't have to keep locking and unlocking while completing
 - Add better locking to tcpip (not the lib)
 - Remove some unused variables
 - Don't hold the cancel spin lock longer than necessary
 - Check that we successfully got the device description

svn path=/trunk/; revision=44250
This commit is contained in:
Cameron Gutman 2009-11-21 13:00:37 +00:00
parent fda7cf929a
commit c71cdeb49d
10 changed files with 216 additions and 172 deletions

View file

@ -991,6 +991,11 @@ BOOLEAN BindAdapter(
GetName( RegistryPath, &IF->Name ); GetName( RegistryPath, &IF->Name );
Status = FindDeviceDescForAdapter( &IF->Name, &IF->Description ); Status = FindDeviceDescForAdapter( &IF->Name, &IF->Description );
if (!NT_SUCCESS(Status)) {
TI_DbgPrint(MIN_TRACE, ("Failed to get device description.\n"));
IPDestroyInterface(IF);
return FALSE;
}
TI_DbgPrint(DEBUG_DATALINK,("Adapter Description: %wZ\n", TI_DbgPrint(DEBUG_DATALINK,("Adapter Description: %wZ\n",
&IF->Description)); &IF->Description));

View file

@ -59,6 +59,11 @@ typedef struct _SLEEPING_THREAD {
KEVENT Event; KEVENT Event;
} SLEEPING_THREAD, *PSLEEPING_THREAD; } SLEEPING_THREAD, *PSLEEPING_THREAD;
typedef struct _CLIENT_DATA {
BOOLEAN Unlocked;
KSPIN_LOCK Lock;
} CLIENT_DATA, *PCLIENT_DATA;
/* Retransmission timeout constants */ /* Retransmission timeout constants */
/* Lower bound for retransmission timeout in TCP timer ticks */ /* Lower bound for retransmission timeout in TCP timer ticks */
@ -84,6 +89,7 @@ typedef struct _SLEEPING_THREAD {
#define SRF_FIN TCP_FIN #define SRF_FIN TCP_FIN
extern LONG TCP_IPIdentification; extern LONG TCP_IPIdentification;
extern CLIENT_DATA ClientInfo;
/* accept.c */ /* accept.c */
NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener, NTSTATUS TCPServiceListeningSocket( PCONNECTION_ENDPOINT Listener,
@ -106,6 +112,8 @@ VOID TCPFreeConnectionEndpoint( PCONNECTION_ENDPOINT Connection );
NTSTATUS TCPSocket( PCONNECTION_ENDPOINT Connection, NTSTATUS TCPSocket( PCONNECTION_ENDPOINT Connection,
UINT Family, UINT Type, UINT Proto ); UINT Family, UINT Type, UINT Proto );
VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection);
PTCP_SEGMENT TCPCreateSegment( PTCP_SEGMENT TCPCreateSegment(
PIP_PACKET IPPacket, PIP_PACKET IPPacket,
PTCPv4_HEADER TCPHeader, PTCPv4_HEADER TCPHeader,
@ -156,8 +164,6 @@ NTSTATUS TCPClose( PCONNECTION_ENDPOINT Connection );
NTSTATUS TCPTranslateError( int OskitError ); NTSTATUS TCPTranslateError( int OskitError );
VOID TCPTimeout();
UINT TCPAllocatePort( UINT HintPort ); UINT TCPAllocatePort( UINT HintPort );
VOID TCPFreePort( UINT Port ); VOID TCPFreePort( UINT Port );

View file

@ -263,6 +263,8 @@ typedef struct _TDI_BUCKET {
LIST_ENTRY Entry; LIST_ENTRY Entry;
struct _CONNECTION_ENDPOINT *AssociatedEndpoint; struct _CONNECTION_ENDPOINT *AssociatedEndpoint;
TDI_REQUEST Request; TDI_REQUEST Request;
NTSTATUS Status;
ULONG Information;
} TDI_BUCKET, *PTDI_BUCKET; } TDI_BUCKET, *PTDI_BUCKET;
/* Transport connection context structure A.K.A. Transmission Control Block /* Transport connection context structure A.K.A. Transmission Control Block
@ -280,6 +282,7 @@ typedef struct _CONNECTION_ENDPOINT {
LIST_ENTRY ListenRequest; /* Queued listen requests */ LIST_ENTRY ListenRequest; /* Queued listen requests */
LIST_ENTRY ReceiveRequest; /* Queued receive requests */ LIST_ENTRY ReceiveRequest; /* Queued receive requests */
LIST_ENTRY SendRequest; /* Queued send requests */ LIST_ENTRY SendRequest; /* Queued send requests */
LIST_ENTRY CompletionQueue;/* Completed requests to finish */
/* Signals */ /* Signals */
UINT SignalState; /* Active signals from oskit */ UINT SignalState; /* Active signals from oskit */

View file

@ -13,6 +13,8 @@ TDI_STATUS SetAddressFileInfo(TDIObjectID *ID,
PVOID Buffer, PVOID Buffer,
UINT BufferSize) UINT BufferSize)
{ {
//KIRQL OldIrql;
switch (ID->toi_id) switch (ID->toi_id)
{ {
#if 0 #if 0
@ -20,7 +22,10 @@ TDI_STATUS SetAddressFileInfo(TDIObjectID *ID,
if (BufferSize < sizeof(UCHAR)) if (BufferSize < sizeof(UCHAR))
return TDI_INVALID_PARAMETER; return TDI_INVALID_PARAMETER;
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql);
AddrFile->TTL = *((PUCHAR)Buffer); AddrFile->TTL = *((PUCHAR)Buffer);
KeReleaseSpinLock(&AddrFile->Lock, OldIrql);
return TDI_SUCCESS; return TDI_SUCCESS;
#endif #endif
default: default:

View file

@ -72,7 +72,6 @@ VOID DispDataRequestComplete(
{ {
PIRP Irp; PIRP Irp;
PIO_STACK_LOCATION IrpSp; PIO_STACK_LOCATION IrpSp;
PTRANSPORT_CONTEXT TranContext;
KIRQL OldIrql; KIRQL OldIrql;
TI_DbgPrint(DEBUG_IRP, ("Called for irp %x (%x, %d).\n", TI_DbgPrint(DEBUG_IRP, ("Called for irp %x (%x, %d).\n",
@ -80,7 +79,6 @@ VOID DispDataRequestComplete(
Irp = Context; Irp = Context;
IrpSp = IoGetCurrentIrpStackLocation(Irp); IrpSp = IoGetCurrentIrpStackLocation(Irp);
TranContext = (PTRANSPORT_CONTEXT)IrpSp->FileObject->FsContext;
IoAcquireCancelSpinLock(&OldIrql); IoAcquireCancelSpinLock(&OldIrql);
@ -117,6 +115,8 @@ VOID NTAPI DispCancelRequest(
PFILE_OBJECT FileObject; PFILE_OBJECT FileObject;
UCHAR MinorFunction; UCHAR MinorFunction;
IoReleaseCancelSpinLock(Irp->CancelIrql);
TI_DbgPrint(DEBUG_IRP, ("Called.\n")); TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
IrpSp = IoGetCurrentIrpStackLocation(Irp); IrpSp = IoGetCurrentIrpStackLocation(Irp);
@ -168,7 +168,6 @@ VOID NTAPI DispCancelRequest(
break; break;
} }
IoReleaseCancelSpinLock(Irp->CancelIrql);
IRPFinish(Irp, STATUS_CANCELLED); IRPFinish(Irp, STATUS_CANCELLED);
TI_DbgPrint(MAX_TRACE, ("Leaving.\n")); TI_DbgPrint(MAX_TRACE, ("Leaving.\n"));
@ -191,6 +190,8 @@ VOID NTAPI DispCancelListenRequest(
PCONNECTION_ENDPOINT Connection; PCONNECTION_ENDPOINT Connection;
/*NTSTATUS Status = STATUS_SUCCESS;*/ /*NTSTATUS Status = STATUS_SUCCESS;*/
IoReleaseCancelSpinLock(Irp->CancelIrql);
TI_DbgPrint(DEBUG_IRP, ("Called.\n")); TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
IrpSp = IoGetCurrentIrpStackLocation(Irp); IrpSp = IoGetCurrentIrpStackLocation(Irp);
@ -213,8 +214,6 @@ VOID NTAPI DispCancelListenRequest(
TCPAbortListenForSocket(Connection->AddressFile->Listener, TCPAbortListenForSocket(Connection->AddressFile->Listener,
Connection); Connection);
IoReleaseCancelSpinLock(Irp->CancelIrql);
Irp->IoStatus.Information = 0; Irp->IoStatus.Information = 0;
IRPFinish(Irp, STATUS_CANCELLED); IRPFinish(Irp, STATUS_CANCELLED);
@ -255,6 +254,7 @@ NTSTATUS DispTdiAssociateAddress(
PFILE_OBJECT FileObject; PFILE_OBJECT FileObject;
PADDRESS_FILE AddrFile = NULL; PADDRESS_FILE AddrFile = NULL;
NTSTATUS Status; NTSTATUS Status;
KIRQL OldIrql;
TI_DbgPrint(DEBUG_IRP, ("Called.\n")); TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
@ -274,11 +274,6 @@ NTSTATUS DispTdiAssociateAddress(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
if (Connection->AddressFile) {
TI_DbgPrint(MID_TRACE, ("An address file is already asscociated.\n"));
return STATUS_INVALID_PARAMETER;
}
Parameters = (PTDI_REQUEST_KERNEL_ASSOCIATE)&IrpSp->Parameters; Parameters = (PTDI_REQUEST_KERNEL_ASSOCIATE)&IrpSp->Parameters;
Status = ObReferenceObjectByHandle( Status = ObReferenceObjectByHandle(
@ -294,8 +289,18 @@ NTSTATUS DispTdiAssociateAddress(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
if (Connection->AddressFile) {
ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(MID_TRACE, ("An address file is already asscociated.\n"));
return STATUS_INVALID_PARAMETER;
}
if (FileObject->FsContext2 != (PVOID)TDI_TRANSPORT_ADDRESS_FILE) { if (FileObject->FsContext2 != (PVOID)TDI_TRANSPORT_ADDRESS_FILE) {
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(MID_TRACE, ("Bad address file object. Magic (0x%X).\n", TI_DbgPrint(MID_TRACE, ("Bad address file object. Magic (0x%X).\n",
FileObject->FsContext2)); FileObject->FsContext2));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
@ -306,17 +311,21 @@ NTSTATUS DispTdiAssociateAddress(
TranContext = FileObject->FsContext; TranContext = FileObject->FsContext;
if (!TranContext) { if (!TranContext) {
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(MID_TRACE, ("Bad transport context.\n")); TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle; AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle;
if (!AddrFile) { if (!AddrFile) {
KeReleaseSpinLock(&Connection->Lock, OldIrql);
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
TI_DbgPrint(MID_TRACE, ("No address file object.\n")); TI_DbgPrint(MID_TRACE, ("No address file object.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLockAtDpcLevel(&AddrFile->Lock);
Connection->AddressFile = AddrFile; Connection->AddressFile = AddrFile;
/* Add connection endpoint to the address file */ /* Add connection endpoint to the address file */
@ -325,6 +334,9 @@ NTSTATUS DispTdiAssociateAddress(
/* FIXME: Maybe do this in DispTdiDisassociateAddress() instead? */ /* FIXME: Maybe do this in DispTdiDisassociateAddress() instead? */
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLockFromDpcLevel(&AddrFile->Lock);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
return Status; return Status;
} }
@ -405,6 +417,7 @@ NTSTATUS DispTdiDisassociateAddress(
PCONNECTION_ENDPOINT Connection; PCONNECTION_ENDPOINT Connection;
PTRANSPORT_CONTEXT TranContext; PTRANSPORT_CONTEXT TranContext;
PIO_STACK_LOCATION IrpSp; PIO_STACK_LOCATION IrpSp;
KIRQL OldIrql;
TI_DbgPrint(DEBUG_IRP, ("Called.\n")); TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
@ -424,17 +437,26 @@ NTSTATUS DispTdiDisassociateAddress(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
if (!Connection->AddressFile) { if (!Connection->AddressFile) {
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(MID_TRACE, ("No address file is asscociated.\n")); TI_DbgPrint(MID_TRACE, ("No address file is asscociated.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLockAtDpcLevel(&Connection->AddressFile->Lock);
/* Remove this connection from the address file */ /* Remove this connection from the address file */
Connection->AddressFile->Connection = NULL; Connection->AddressFile->Connection = NULL;
KeReleaseSpinLockFromDpcLevel(&Connection->AddressFile->Lock);
/* Remove the address file from this connection */ /* Remove the address file from this connection */
Connection->AddressFile = NULL; Connection->AddressFile = NULL;
KeReleaseSpinLock(&Connection->Lock, OldIrql);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
@ -511,6 +533,7 @@ NTSTATUS DispTdiListen(
PTRANSPORT_CONTEXT TranContext; PTRANSPORT_CONTEXT TranContext;
PIO_STACK_LOCATION IrpSp; PIO_STACK_LOCATION IrpSp;
NTSTATUS Status = STATUS_SUCCESS; NTSTATUS Status = STATUS_SUCCESS;
KIRQL OldIrql;
TI_DbgPrint(DEBUG_IRP, ("Called.\n")); TI_DbgPrint(DEBUG_IRP, ("Called.\n"));
@ -536,15 +559,23 @@ NTSTATUS DispTdiListen(
Parameters = (PTDI_REQUEST_KERNEL)&IrpSp->Parameters; Parameters = (PTDI_REQUEST_KERNEL)&IrpSp->Parameters;
TI_DbgPrint(MIN_TRACE, ("Connection->AddressFile: %x\n",
Connection->AddressFile ));
ASSERT(Connection->AddressFile);
Status = DispPrepareIrpForCancel Status = DispPrepareIrpForCancel
(TranContext->Handle.ConnectionContext, (TranContext->Handle.ConnectionContext,
Irp, Irp,
(PDRIVER_CANCEL)DispCancelListenRequest); (PDRIVER_CANCEL)DispCancelListenRequest);
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
if (Connection->AddressFile == NULL)
{
TI_DbgPrint(MID_TRACE, ("No associated address file\n"));
KeReleaseSpinLock(&Connection->Lock, OldIrql);
Status = STATUS_INVALID_PARAMETER;
goto done;
}
KeAcquireSpinLockAtDpcLevel(&Connection->AddressFile->Lock);
/* Listening will require us to create a listening socket and store it in /* Listening will require us to create a listening socket and store it in
* the address file. It will be signalled, and attempt to complete an irp * the address file. It will be signalled, and attempt to complete an irp
* when a new connection arrives. */ * when a new connection arrives. */
@ -581,6 +612,9 @@ NTSTATUS DispTdiListen(
Irp ); Irp );
} }
KeReleaseSpinLockFromDpcLevel(&Connection->AddressFile->Lock);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
done: done:
if (Status != STATUS_PENDING) { if (Status != STATUS_PENDING) {
DispDataRequestComplete(Irp, Status, 0); DispDataRequestComplete(Irp, Status, 0);
@ -657,12 +691,10 @@ NTSTATUS DispTdiQueryInformation(
case TDI_CONNECTION_FILE: case TDI_CONNECTION_FILE:
Endpoint = Endpoint =
(PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext; (PCONNECTION_ENDPOINT)TranContext->Handle.ConnectionContext;
TCPGetSockAddress( Endpoint, (PTRANSPORT_ADDRESS)Address, FALSE );
DbgPrint("Returning socket address %x\n", Address->Address[0].Address[0].in_addr);
RtlZeroMemory( RtlZeroMemory(
&Address->Address[0].Address[0].sin_zero, &Address->Address[0].Address[0].sin_zero,
sizeof(Address->Address[0].Address[0].sin_zero)); sizeof(Address->Address[0].Address[0].sin_zero));
return STATUS_SUCCESS; return TCPGetSockAddress( Endpoint, (PTRANSPORT_ADDRESS)Address, FALSE );
default: default:
TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n")); TI_DbgPrint(MIN_TRACE, ("Invalid transport context\n"));
@ -990,8 +1022,10 @@ NTSTATUS DispTdiSendDatagram(
DataBuffer, DataBuffer,
BufferSize, BufferSize,
&Irp->IoStatus.Information); &Irp->IoStatus.Information);
else else {
Status = STATUS_UNSUCCESSFUL; Status = STATUS_UNSUCCESSFUL;
ASSERT(FALSE);
}
} }
done: done:
@ -1199,11 +1233,10 @@ VOID DispTdiQueryInformationExComplete(
*/ */
{ {
PTI_QUERY_CONTEXT QueryContext; PTI_QUERY_CONTEXT QueryContext;
UINT Count = 0;
QueryContext = (PTI_QUERY_CONTEXT)Context; QueryContext = (PTI_QUERY_CONTEXT)Context;
if (NT_SUCCESS(Status)) { if (NT_SUCCESS(Status)) {
Count = CopyBufferToBufferChain( CopyBufferToBufferChain(
QueryContext->InputMdl, QueryContext->InputMdl,
FIELD_OFFSET(TCP_REQUEST_QUERY_INFORMATION_EX, Context), FIELD_OFFSET(TCP_REQUEST_QUERY_INFORMATION_EX, Context),
(PCHAR)&QueryContext->QueryInfo.Context, (PCHAR)&QueryContext->QueryInfo.Context,

View file

@ -322,7 +322,6 @@ TiDispatchOpenClose(
{ {
PIO_STACK_LOCATION IrpSp; PIO_STACK_LOCATION IrpSp;
NTSTATUS Status; NTSTATUS Status;
PTRANSPORT_CONTEXT Context;
IRPRemember(Irp, __FILE__, __LINE__); IRPRemember(Irp, __FILE__, __LINE__);
@ -338,8 +337,7 @@ TiDispatchOpenClose(
/* Close an address file, connection endpoint, or control connection */ /* Close an address file, connection endpoint, or control connection */
case IRP_MJ_CLOSE: case IRP_MJ_CLOSE:
Context = (PTRANSPORT_CONTEXT)IrpSp->FileObject->FsContext; Status = TiCloseFileObject(DeviceObject, Irp);
Status = TiCloseFileObject(DeviceObject, Irp);
break; break;
default: default:

View file

@ -73,9 +73,6 @@ void NTAPI IPTimeout( PVOID Context ) {
/* Clean possible outdated cached neighbor addresses */ /* Clean possible outdated cached neighbor addresses */
NBTimeout(); NBTimeout();
/* Call upper layer timeout routines */
TCPTimeout();
} }

View file

@ -29,7 +29,8 @@ int TCPSocketState(void *ClientData,
return 0; return 0;
} }
KeAcquireSpinLockAtDpcLevel(&Connection->Lock); if (ClientInfo.Unlocked)
KeAcquireSpinLockAtDpcLevel(&Connection->Lock);
TI_DbgPrint(DEBUG_TCP,("Called: NewState %x (Conn %x) (Change %x)\n", TI_DbgPrint(DEBUG_TCP,("Called: NewState %x (Conn %x) (Change %x)\n",
NewState, Connection, NewState, Connection,
@ -38,7 +39,10 @@ int TCPSocketState(void *ClientData,
Connection->SignalState |= NewState; Connection->SignalState |= NewState;
KeReleaseSpinLockFromDpcLevel(&Connection->Lock); HandleSignalledConnection(Connection);
if (ClientInfo.Unlocked)
KeReleaseSpinLockFromDpcLevel(&Connection->Lock);
return 0; return 0;
} }

View file

@ -16,30 +16,38 @@ LONG TCP_IPIdentification = 0;
static BOOLEAN TCPInitialized = FALSE; static BOOLEAN TCPInitialized = FALSE;
static NPAGED_LOOKASIDE_LIST TCPSegmentList; static NPAGED_LOOKASIDE_LIST TCPSegmentList;
PORT_SET TCPPorts; PORT_SET TCPPorts;
CLIENT_DATA ClientInfo;
static VOID DrainSignals() { static VOID
PCONNECTION_ENDPOINT Connection; ProcessCompletions(PCONNECTION_ENDPOINT Connection)
PLIST_ENTRY CurrentEntry, NextEntry; {
KIRQL OldIrql; PLIST_ENTRY CurrentEntry;
NTSTATUS Status = STATUS_SUCCESS;
PTCP_COMPLETION_ROUTINE Complete;
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
PLIST_ENTRY Entry; PTCP_COMPLETION_ROUTINE Complete;
PIRP Irp;
PMDL Mdl;
ULONG SocketError;
KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql); while ((CurrentEntry = ExInterlockedRemoveHeadList(&Connection->CompletionQueue,
CurrentEntry = ConnectionEndpointListHead.Flink; &Connection->Lock)))
while (CurrentEntry != &ConnectionEndpointListHead)
{ {
NextEntry = CurrentEntry->Flink; Bucket = CONTAINING_RECORD(CurrentEntry, TDI_BUCKET, Entry);
KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql); Complete = Bucket->Request.RequestNotifyObject;
Connection = CONTAINING_RECORD( CurrentEntry, CONNECTION_ENDPOINT, Complete(Bucket->Request.RequestContext, Bucket->Status, Bucket->Information);
ListEntry );
KeAcquireSpinLock(&Connection->Lock, &OldIrql); exFreePool(Bucket);
}
if (!Connection->SocketContext)
TCPFreeConnectionEndpoint(Connection);
}
VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
{
PTDI_BUCKET Bucket;
PLIST_ENTRY Entry;
NTSTATUS Status;
PIRP Irp;
PMDL Mdl;
ULONG SocketError;
TI_DbgPrint(MID_TRACE,("Handling signalled state on %x (%x)\n", TI_DbgPrint(MID_TRACE,("Handling signalled state on %x (%x)\n",
Connection, Connection->SocketContext)); Connection, Connection->SocketContext));
@ -47,8 +55,6 @@ static VOID DrainSignals() {
if( !Connection->SocketContext || Connection->SignalState & SEL_FIN ) { if( !Connection->SocketContext || Connection->SignalState & SEL_FIN ) {
TI_DbgPrint(DEBUG_TCP, ("EOF From socket\n")); TI_DbgPrint(DEBUG_TCP, ("EOF From socket\n"));
Connection->SignalState = 0;
/* If OskitTCP initiated the disconnect, try to read the socket error that occurred */ /* If OskitTCP initiated the disconnect, try to read the socket error that occurred */
if (Connection->SocketContext) if (Connection->SocketContext)
SocketError = TCPTranslateError(OskitTCPGetSocketError(Connection->SocketContext)); SocketError = TCPTranslateError(OskitTCPGetSocketError(Connection->SocketContext));
@ -57,140 +63,124 @@ static VOID DrainSignals() {
if (!Connection->SocketContext || !SocketError) if (!Connection->SocketContext || !SocketError)
SocketError = STATUS_CANCELLED; SocketError = STATUS_CANCELLED;
KeReleaseSpinLock(&Connection->Lock, OldIrql); while (!IsListEmpty(&Connection->ReceiveRequest))
while ((Entry = ExInterlockedRemoveHeadList( &Connection->ReceiveRequest,
&Connection->Lock )) != NULL)
{ {
Entry = RemoveHeadList( &Connection->ReceiveRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Complete( Bucket->Request.RequestContext, SocketError, 0 ); Bucket->Status = SocketError;
Bucket->Information = 0;
exFreePool(Bucket); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
while ((Entry = ExInterlockedRemoveHeadList( &Connection->SendRequest, while (!IsListEmpty(&Connection->SendRequest))
&Connection->Lock )) != NULL)
{ {
Entry = RemoveHeadList( &Connection->SendRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Complete( Bucket->Request.RequestContext, SocketError, 0 ); Bucket->Status = SocketError;
Bucket->Information = 0;
exFreePool(Bucket); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
while ((Entry = ExInterlockedRemoveHeadList( &Connection->ListenRequest, while (!IsListEmpty(&Connection->ListenRequest))
&Connection->Lock )) != NULL)
{ {
Entry = RemoveHeadList( &Connection->ListenRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Complete( Bucket->Request.RequestContext, SocketError, 0 ); Bucket->Status = SocketError;
Bucket->Information = 0;
exFreePool(Bucket); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
while ((Entry = ExInterlockedRemoveHeadList( &Connection->ConnectRequest, while (!IsListEmpty(&Connection->ConnectRequest))
&Connection->Lock )) != NULL)
{ {
Entry = RemoveHeadList( &Connection->ConnectRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Complete( Bucket->Request.RequestContext, SocketError, 0 ); Bucket->Status = SocketError;
Bucket->Information = 0;
exFreePool(Bucket); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql); Connection->SignalState = 0;
} }
/* Things that can happen when we try the initial connection */ /* Things that can happen when we try the initial connection */
if( Connection->SignalState & SEL_CONNECT ) { if( Connection->SignalState & SEL_CONNECT ) {
KeReleaseSpinLock(&Connection->Lock, OldIrql); while (!IsListEmpty(&Connection->ConnectRequest)) {
while( (Entry = ExInterlockedRemoveHeadList( &Connection->ConnectRequest, Entry = RemoveHeadList( &Connection->ConnectRequest );
&Connection->Lock )) != NULL ) {
TI_DbgPrint(DEBUG_TCP, ("Connect Event\n"));
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
TI_DbgPrint(DEBUG_TCP,
("Completing Request %x\n", Bucket->Request.RequestContext));
Complete( Bucket->Request.RequestContext, STATUS_SUCCESS, 0 ); Bucket->Status = STATUS_SUCCESS;
Bucket->Information = 0;
/* Frees the bucket allocated in TCPConnect */ InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
exFreePool( Bucket );
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
} }
if( Connection->SignalState & SEL_ACCEPT ) { if( Connection->SignalState & SEL_ACCEPT ) {
/* Handle readable on a listening socket -- /* Handle readable on a listening socket --
* TODO: Implement filtering * TODO: Implement filtering
*/ */
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("Accepting new connection on %x (Queue: %s)\n", TI_DbgPrint(DEBUG_TCP,("Accepting new connection on %x (Queue: %s)\n",
Connection, Connection,
IsListEmpty(&Connection->ListenRequest) ? IsListEmpty(&Connection->ListenRequest) ?
"empty" : "nonempty")); "empty" : "nonempty"));
while( (Entry = ExInterlockedRemoveHeadList( &Connection->ListenRequest, while (!IsListEmpty(&Connection->ListenRequest)) {
&Connection->Lock )) != NULL ) {
PIO_STACK_LOCATION IrpSp; PIO_STACK_LOCATION IrpSp;
Entry = RemoveHeadList( &Connection->ListenRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Irp = Bucket->Request.RequestContext; Irp = Bucket->Request.RequestContext;
IrpSp = IoGetCurrentIrpStackLocation( Irp ); IrpSp = IoGetCurrentIrpStackLocation( Irp );
TI_DbgPrint(DEBUG_TCP,("Getting the socket\n")); TI_DbgPrint(DEBUG_TCP,("Getting the socket\n"));
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
Status = TCPServiceListeningSocket Status = TCPServiceListeningSocket
( Connection->AddressFile->Listener, ( Connection->AddressFile->Listener,
Bucket->AssociatedEndpoint, Bucket->AssociatedEndpoint,
(PTDI_REQUEST_KERNEL)&IrpSp->Parameters ); (PTDI_REQUEST_KERNEL)&IrpSp->Parameters );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("Socket: Status: %x\n")); TI_DbgPrint(DEBUG_TCP,("Socket: Status: %x\n"));
if( Status == STATUS_PENDING ) { if( Status == STATUS_PENDING ) {
ExInterlockedInsertHeadList( &Connection->ListenRequest, &Bucket->Entry, InsertHeadList( &Connection->ListenRequest, &Bucket->Entry );
&Connection->Lock );
break; break;
} else { } else {
Complete( Bucket->Request.RequestContext, Status, 0 ); Bucket->Status = Status;
exFreePool( Bucket ); Bucket->Information = 0;
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
} }
/* Things that happen after we're connected */ /* Things that happen after we're connected */
if( Connection->SignalState & SEL_READ && if( Connection->SignalState & SEL_READ ) {
Connection->SignalState & SEL_CONNECT ) {
TI_DbgPrint(DEBUG_TCP,("Readable: irp list %s\n", TI_DbgPrint(DEBUG_TCP,("Readable: irp list %s\n",
IsListEmpty(&Connection->ReceiveRequest) ? IsListEmpty(&Connection->ReceiveRequest) ?
"empty" : "nonempty")); "empty" : "nonempty"));
KeReleaseSpinLock(&Connection->Lock, OldIrql); while (!IsListEmpty(&Connection->ReceiveRequest)) {
while( (Entry = ExInterlockedRemoveHeadList( &Connection->ReceiveRequest,
&Connection->Lock )) != NULL ) {
OSK_UINT RecvLen = 0, Received = 0; OSK_UINT RecvLen = 0, Received = 0;
PVOID RecvBuffer = 0; PVOID RecvBuffer = 0;
Entry = RemoveHeadList( &Connection->ReceiveRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Irp = Bucket->Request.RequestContext; Irp = Bucket->Request.RequestContext;
Mdl = Irp->MdlAddress; Mdl = Irp->MdlAddress;
@ -210,8 +200,6 @@ static VOID DrainSignals() {
Connection->SocketContext)); Connection->SocketContext));
TI_DbgPrint(DEBUG_TCP, ("RecvBuffer: %x\n", RecvBuffer)); TI_DbgPrint(DEBUG_TCP, ("RecvBuffer: %x\n", RecvBuffer));
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
Status = TCPTranslateError Status = TCPTranslateError
( OskitTCPRecv( Connection->SocketContext, ( OskitTCPRecv( Connection->SocketContext,
RecvBuffer, RecvBuffer,
@ -219,46 +207,35 @@ static VOID DrainSignals() {
&Received, &Received,
0 ) ); 0 ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCP Bytes: %d\n", Received)); TI_DbgPrint(DEBUG_TCP,("TCP Bytes: %d\n", Received));
if( Status == STATUS_SUCCESS ) { if( Status == STATUS_PENDING ) {
TI_DbgPrint(DEBUG_TCP,("Received %d bytes with status %x\n", InsertHeadList( &Connection->ReceiveRequest, &Bucket->Entry );
Received, Status));
Complete( Bucket->Request.RequestContext,
STATUS_SUCCESS, Received );
exFreePool( Bucket );
} else if( Status == STATUS_PENDING ) {
ExInterlockedInsertHeadList( &Connection->ReceiveRequest, &Bucket->Entry,
&Connection->Lock );
break; break;
} else { } else {
TI_DbgPrint(DEBUG_TCP, TI_DbgPrint(DEBUG_TCP,
("Completing Receive request: %x %x\n", ("Completing Receive request: %x %x\n",
Bucket->Request, Status)); Bucket->Request, Status));
Complete( Bucket->Request.RequestContext, Status, 0 );
exFreePool( Bucket ); Bucket->Status = Status;
Bucket->Information = (Status == STATUS_SUCCESS) ? Received : 0;
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
} }
if( Connection->SignalState & SEL_WRITE && if( Connection->SignalState & SEL_WRITE ) {
Connection->SignalState & SEL_CONNECT ) {
TI_DbgPrint(DEBUG_TCP,("Writeable: irp list %s\n", TI_DbgPrint(DEBUG_TCP,("Writeable: irp list %s\n",
IsListEmpty(&Connection->SendRequest) ? IsListEmpty(&Connection->SendRequest) ?
"empty" : "nonempty")); "empty" : "nonempty"));
KeReleaseSpinLock(&Connection->Lock, OldIrql); while (!IsListEmpty(&Connection->SendRequest)) {
while( (Entry = ExInterlockedRemoveHeadList( &Connection->SendRequest,
&Connection->Lock )) != NULL ) {
OSK_UINT SendLen = 0, Sent = 0; OSK_UINT SendLen = 0, Sent = 0;
PVOID SendBuffer = 0; PVOID SendBuffer = 0;
Entry = RemoveHeadList( &Connection->SendRequest );
Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry ); Bucket = CONTAINING_RECORD( Entry, TDI_BUCKET, Entry );
Complete = Bucket->Request.RequestNotifyObject;
Irp = Bucket->Request.RequestContext; Irp = Bucket->Request.RequestContext;
Mdl = Irp->MdlAddress; Mdl = Irp->MdlAddress;
@ -277,8 +254,6 @@ static VOID DrainSignals() {
("Connection->SocketContext: %x\n", ("Connection->SocketContext: %x\n",
Connection->SocketContext)); Connection->SocketContext));
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
Status = TCPTranslateError Status = TCPTranslateError
( OskitTCPSend( Connection->SocketContext, ( OskitTCPSend( Connection->SocketContext,
SendBuffer, SendBuffer,
@ -286,43 +261,55 @@ static VOID DrainSignals() {
&Sent, &Sent,
0 ) ); 0 ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCP Bytes: %d\n", Sent)); TI_DbgPrint(DEBUG_TCP,("TCP Bytes: %d\n", Sent));
if( Status == STATUS_SUCCESS ) { if( Status == STATUS_PENDING ) {
TI_DbgPrint(DEBUG_TCP,("Sent %d bytes with status %x\n", InsertHeadList( &Connection->SendRequest, &Bucket->Entry );
Sent, Status));
Complete( Bucket->Request.RequestContext,
STATUS_SUCCESS, Sent );
exFreePool( Bucket );
} else if( Status == STATUS_PENDING ) {
ExInterlockedInsertHeadList( &Connection->SendRequest, &Bucket->Entry,
&Connection->Lock );
break; break;
} else { } else {
TI_DbgPrint(DEBUG_TCP, TI_DbgPrint(DEBUG_TCP,
("Completing Send request: %x %x\n", ("Completing Send request: %x %x\n",
Bucket->Request, Status)); Bucket->Request, Status));
Complete( Bucket->Request.RequestContext, Status, 0 );
exFreePool( Bucket ); Bucket->Status = Status;
Bucket->Information = (Status == STATUS_SUCCESS) ? Sent : 0;
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql);
} }
}
KeReleaseSpinLock(&Connection->Lock, OldIrql); static
VOID DrainSignals(VOID) {
PCONNECTION_ENDPOINT Connection;
PLIST_ENTRY CurrentEntry;
KIRQL OldIrql;
if (!Connection->SocketContext) KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
{ CurrentEntry = ConnectionEndpointListHead.Flink;
TCPFreeConnectionEndpoint(Connection); while (CurrentEntry != &ConnectionEndpointListHead)
} {
Connection = CONTAINING_RECORD( CurrentEntry, CONNECTION_ENDPOINT,
ListEntry );
CurrentEntry = CurrentEntry->Flink;
KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
CurrentEntry = NextEntry; KeAcquireSpinLock(&Connection->Lock, &OldIrql);
if (Connection->SocketContext)
{
HandleSignalledConnection(Connection);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql); ProcessCompletions(Connection);
}
else
{
KeReleaseSpinLock(&Connection->Lock, OldIrql);
}
KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
} }
KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql); KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
} }
@ -342,6 +329,7 @@ PCONNECTION_ENDPOINT TCPAllocateConnectionEndpoint( PVOID ClientContext ) {
InitializeListHead(&Connection->ListenRequest); InitializeListHead(&Connection->ListenRequest);
InitializeListHead(&Connection->ReceiveRequest); InitializeListHead(&Connection->ReceiveRequest);
InitializeListHead(&Connection->SendRequest); InitializeListHead(&Connection->SendRequest);
InitializeListHead(&Connection->CompletionQueue);
/* Save client context pointer */ /* Save client context pointer */
Connection->ClientContext = ClientContext; Connection->ClientContext = ClientContext;
@ -402,13 +390,21 @@ VOID TCPReceive(PIP_INTERFACE Interface, PIP_PACKET IPPacket)
* This is the low level interface for receiving TCP data * This is the low level interface for receiving TCP data
*/ */
{ {
KIRQL OldIrql;
TI_DbgPrint(DEBUG_TCP,("Sending packet %d (%d) to oskit\n", TI_DbgPrint(DEBUG_TCP,("Sending packet %d (%d) to oskit\n",
IPPacket->TotalSize, IPPacket->TotalSize,
IPPacket->HeaderSize)); IPPacket->HeaderSize));
KeAcquireSpinLock(&ClientInfo.Lock, &OldIrql);
ClientInfo.Unlocked = TRUE;
OskitTCPReceiveDatagram( IPPacket->Header, OskitTCPReceiveDatagram( IPPacket->Header,
IPPacket->TotalSize, IPPacket->TotalSize,
IPPacket->HeaderSize ); IPPacket->HeaderSize );
ClientInfo.Unlocked = FALSE;
KeReleaseSpinLock(&ClientInfo.Lock, OldIrql);
} }
/* event.c */ /* event.c */
@ -467,7 +463,7 @@ TimerThread(PVOID Context)
while ( 1 ) { while ( 1 ) {
if (Next == NextFast) { if (Next == NextFast) {
NextFast += 2; NextFast += 2;
} }
if (Next == NextSlow) { if (Next == NextSlow) {
NextSlow += 5; NextSlow += 5;
} }
@ -480,9 +476,7 @@ TimerThread(PVOID Context)
} }
TimerOskitTCP( Next == NextFast, Next == NextSlow ); TimerOskitTCP( Next == NextFast, Next == NextSlow );
if (Next == NextSlow) { DrainSignals();
DrainSignals();
}
Current = Next; Current = Next;
if (10 <= Current) { if (10 <= Current) {
@ -502,7 +496,6 @@ StartTimer(VOID)
TimerThread, NULL); TimerThread, NULL);
} }
NTSTATUS TCPStartup(VOID) NTSTATUS TCPStartup(VOID)
/* /*
* FUNCTION: Initializes the TCP subsystem * FUNCTION: Initializes the TCP subsystem
@ -523,6 +516,9 @@ NTSTATUS TCPStartup(VOID)
return Status; return Status;
} }
KeInitializeSpinLock(&ClientInfo.Lock);
ClientInfo.Unlocked = FALSE;
RegisterOskitTCPEventHandlers( &EventHandlers ); RegisterOskitTCPEventHandlers( &EventHandlers );
InitOskitTCP(); InitOskitTCP();
@ -862,10 +858,6 @@ NTSTATUS TCPSendData
return Status; return Status;
} }
VOID TCPTimeout(VOID) {
/* Now handled by TimerThread */
}
UINT TCPAllocatePort( UINT HintPort ) { UINT TCPAllocatePort( UINT HintPort ) {
if( HintPort ) { if( HintPort ) {
if( AllocatePort( &TCPPorts, HintPort ) ) return HintPort; if( AllocatePort( &TCPPorts, HintPort ) ) return HintPort;

View file

@ -285,11 +285,16 @@ int OskitTCPShutdown( void *socket, int disconn_type ) {
int OskitTCPClose( void *socket ) { int OskitTCPClose( void *socket ) {
int error; int error;
struct socket *so = socket;
if (!socket) if (!socket)
return OSK_ESHUTDOWN; return OSK_ESHUTDOWN;
OSKLock(); OSKLock();
/* We have to remove the socket context here otherwise we end up
* back in HandleSignalledConnection with a freed connection context
*/
so->so_connection = NULL;
error = soclose( socket ); error = soclose( socket );
OSKUnlock(); OSKUnlock();
@ -435,16 +440,12 @@ void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
OSK_UINT IpHeaderLen ) { OSK_UINT IpHeaderLen ) {
struct mbuf *Ip; struct mbuf *Ip;
struct ip *iph; struct ip *iph;
KIRQL OldIrql;
/* This function is a special case in which we cannot use OSKLock/OSKUnlock OSKLock();
* because we don't enter with the connection lock held */
OSKLockAndRaise(&OldIrql);
Ip = m_devget( (char *)Data, Len, 0, NULL, NULL ); Ip = m_devget( (char *)Data, Len, 0, NULL, NULL );
if( !Ip ) if( !Ip )
{ {
OSKUnlockAndLower(OldIrql); OSKUnlock();
return; /* drop the segment */ return; /* drop the segment */
} }
@ -461,7 +462,7 @@ void OskitTCPReceiveDatagram( OSK_PCHAR Data, OSK_UINT Len,
IpHeaderLen)); IpHeaderLen));
tcp_input(Ip, IpHeaderLen); tcp_input(Ip, IpHeaderLen);
OSKUnlockAndLower(OldIrql); OSKUnlock();
/* The buffer Ip is freed by tcp_input */ /* The buffer Ip is freed by tcp_input */
} }