[TCPIP, IP]

- Completely rewrite (again) the locking code and use references to ensure that the connection doesn't get freed while completing requests (the most frequent cause of crashes)
 - Remove DrainSignals and complete requests inside HandleSignalledConnection instead of doing them in a separate thread (increases speed a lot)
[OSKITTCP]
 - Don't clear the socket context in OskitTCPClose because we would end up in HandleSignalledConnection without a connection (which we don't support anymore after eliminating DrainSignals)
 - Change the check performed to see if a socket is dying so we support connection dying after calling OskitTCPClose
[AFD]
 - Remove leftover ASSERTs which fail after the changes to tcpip (they were wrong in the first place because we call into tcpip at DISPATCH_LEVEL sometimes)

svn path=/branches/aicom-network-branch/; revision=44839
This commit is contained in:
Cameron Gutman 2009-12-31 23:33:24 +00:00
parent 175500c30b
commit 8dd3966ba9
15 changed files with 302 additions and 320 deletions

View file

@ -189,8 +189,6 @@ NTSTATUS NTAPI ReceiveComplete
AFD_DbgPrint(MID_TRACE,("Called\n")); AFD_DbgPrint(MID_TRACE,("Called\n"));
ASSERT_IRQL(APC_LEVEL);
if( !SocketAcquireStateLock( FCB ) ) if( !SocketAcquireStateLock( FCB ) )
return STATUS_FILE_CLOSED; return STATUS_FILE_CLOSED;

View file

@ -38,8 +38,6 @@ static NTSTATUS NTAPI SendComplete
Irp->IoStatus.Status, Irp->IoStatus.Status,
Irp->IoStatus.Information)); Irp->IoStatus.Information));
ASSERT_IRQL(APC_LEVEL);
if( !SocketAcquireStateLock( FCB ) ) if( !SocketAcquireStateLock( FCB ) )
return STATUS_FILE_CLOSED; return STATUS_FILE_CLOSED;

View file

@ -53,13 +53,10 @@ extern DWORD DebugTraceLevel;
#define ASSERT(x) if (!(x)) { AFD_DbgPrint(MIN_TRACE, ("Assertion "#x" failed at %s:%d\n", __FILE__, __LINE__)); DbgBreakPoint(); } #define ASSERT(x) if (!(x)) { AFD_DbgPrint(MIN_TRACE, ("Assertion "#x" failed at %s:%d\n", __FILE__, __LINE__)); DbgBreakPoint(); }
#endif /* NASSERT */ #endif /* NASSERT */
#define ASSERT_IRQL(x) ASSERT(KeGetCurrentIrql() <= (x))
#else /* DBG */ #else /* DBG */
#define AFD_DbgPrint(_t_, _x_) #define AFD_DbgPrint(_t_, _x_)
#define ASSERT_IRQL(x)
#define ASSERTKM(x) #define ASSERTKM(x)
#ifndef ASSERT #ifndef ASSERT
#define ASSERT(x) #define ASSERT(x)
@ -70,7 +67,6 @@ extern DWORD DebugTraceLevel;
#undef assert #undef assert
#define assert(x) ASSERT(x) #define assert(x) ASSERT(x)
#define assert_irql(x) ASSERT_IRQL(x)
#ifdef _MSC_VER #ifdef _MSC_VER

View file

@ -62,6 +62,7 @@ typedef struct _SLEEPING_THREAD {
typedef struct _CLIENT_DATA { typedef struct _CLIENT_DATA {
BOOLEAN Unlocked; BOOLEAN Unlocked;
KSPIN_LOCK Lock; KSPIN_LOCK Lock;
KIRQL OldIrql;
} CLIENT_DATA, *PCLIENT_DATA; } CLIENT_DATA, *PCLIENT_DATA;
/* Retransmission timeout constants */ /* Retransmission timeout constants */

View file

@ -7,79 +7,63 @@
#ifndef __TITYPES_H #ifndef __TITYPES_H
#define __TITYPES_H #define __TITYPES_H
#if DBG
#define DEBUG_REFCHECK(Object) { \
if ((Object)->RefCount <= 0) { \
TI_DbgPrint(MIN_TRACE, ("Object at (0x%X) has invalid reference count (%d).\n", \
(Object), (Object)->RefCount)); \
} \
}
/* /*
* VOID ReferenceObject( * VOID ReferenceObject(
* PVOID Object) * PVOID Object)
*/ */
#define ReferenceObject(Object) \ #define ReferenceObject(Object) \
{ \ { \
CHAR c1, c2, c3, c4; \ InterlockedIncrement(&((Object)->RefCount)); \
\
c1 = ((Object)->Tag >> 24) & 0xFF; \
c2 = ((Object)->Tag >> 16) & 0xFF; \
c3 = ((Object)->Tag >> 8) & 0xFF; \
c4 = ((Object)->Tag & 0xFF); \
\
DEBUG_REFCHECK(Object); \
TI_DbgPrint(DEBUG_REFCOUNT, ("Referencing object of type (%c%c%c%c) at (0x%X). RefCount (%d).\n", \
c4, c3, c2, c1, (Object), (Object)->RefCount)); \
\
InterlockedIncrement(&((Object)->RefCount)); \
}
/*
* VOID DereferenceObject(
* PVOID Object)
*/
#define DereferenceObject(Object) \
{ \
CHAR c1, c2, c3, c4; \
\
c1 = ((Object)->Tag >> 24) & 0xFF; \
c2 = ((Object)->Tag >> 16) & 0xFF; \
c3 = ((Object)->Tag >> 8) & 0xFF; \
c4 = ((Object)->Tag & 0xFF); \
\
DEBUG_REFCHECK(Object); \
TI_DbgPrint(DEBUG_REFCOUNT, ("Dereferencing object of type (%c%c%c%c) at (0x%X). RefCount (%d).\n", \
c4, c3, c2, c1, (Object), (Object)->RefCount)); \
\
if (InterlockedDecrement(&((Object)->RefCount)) == 0) \
(((Object)->Free)(Object)); \
}
#else /* DBG */
/*
* VOID ReferenceObject(
* PVOID Object)
*/
#define ReferenceObject(Object) \
{ \
InterlockedIncrement(&((Object)->RefCount)); \
} }
/* /*
* VOID DereferenceObject( * VOID DereferenceObject(
* PVOID Object) * PVOID Object)
*/ */
#define DereferenceObject(Object) \ #define DereferenceObject(Object) \
{ \ { \
if (InterlockedDecrement(&((Object)->RefCount)) == 0) \ if (InterlockedDecrement(&((Object)->RefCount)) == 0) \
(((Object)->Free)(Object)); \ (((Object)->Free)(Object)); \
}
/*
* VOID LockObject(PVOID Object, PKIRQL OldIrql)
*/
#define LockObject(Object, Irql) \
{ \
ReferenceObject(Object); \
KeAcquireSpinLock(&((Object)->Lock), Irql); \
memcpy(&(Object)->OldIrql, Irql, sizeof(KIRQL)); \
}
/*
* VOID LockObjectAtDpcLevel(PVOID Object)
*/
#define LockObjectAtDpcLevel(Object) \
{ \
ReferenceObject(Object); \
KeAcquireSpinLockAtDpcLevel(&((Object)->Lock)); \
(Object)->OldIrql = DISPATCH_LEVEL; \
}
/*
* VOID UnlockObject(PVOID Object, KIRQL OldIrql)
*/
#define UnlockObject(Object, OldIrql) \
{ \
KeReleaseSpinLock(&((Object)->Lock), OldIrql); \
DereferenceObject(Object); \
}
/*
* VOID UnlockObjectFromDpcLevel(PVOID Object)
*/
#define UnlockObjectFromDpcLevel(Object) \
{ \
KeReleaseSpinLockFromDpcLevel(&((Object)->Lock)); \
DereferenceObject(Object); \
} }
#endif /* DBG */
#include <ip.h> #include <ip.h>
@ -143,8 +127,10 @@ typedef struct _DATAGRAM_SEND_REQUEST {
field holds a pointer to this structure */ field holds a pointer to this structure */
typedef struct _ADDRESS_FILE { typedef struct _ADDRESS_FILE {
LIST_ENTRY ListEntry; /* Entry on list */ LIST_ENTRY ListEntry; /* Entry on list */
KSPIN_LOCK Lock; /* Spin lock to manipulate this structure */ LONG RefCount; /* Reference count */
OBJECT_FREE_ROUTINE Free; /* Routine to use to free resources for the object */ OBJECT_FREE_ROUTINE Free; /* Routine to use to free resources for the object */
KSPIN_LOCK Lock; /* Spin lock to manipulate this structure */
KIRQL OldIrql; /* Currently not used */
IP_ADDRESS Address; /* Address of this address file */ IP_ADDRESS Address; /* Address of this address file */
USHORT Family; /* Address family */ USHORT Family; /* Address family */
USHORT Protocol; /* Protocol number */ USHORT Protocol; /* Protocol number */
@ -264,7 +250,10 @@ typedef struct _TDI_BUCKET {
to this structure */ to this structure */
typedef struct _CONNECTION_ENDPOINT { typedef struct _CONNECTION_ENDPOINT {
LIST_ENTRY ListEntry; /* Entry on list */ LIST_ENTRY ListEntry; /* Entry on list */
LONG RefCount; /* Reference count */
OBJECT_FREE_ROUTINE Free; /* Routine to use to free resources for the object */
KSPIN_LOCK Lock; /* Spin lock to protect this structure */ KSPIN_LOCK Lock; /* Spin lock to protect this structure */
KIRQL OldIrql; /* The old irql is stored here for use in HandleSignalledConnection */
PVOID ClientContext; /* Pointer to client context information */ PVOID ClientContext; /* Pointer to client context information */
PADDRESS_FILE AddressFile; /* Associated address file object (NULL if none) */ PADDRESS_FILE AddressFile; /* Associated address file object (NULL if none) */
PVOID SocketContext; /* Context for lower layer */ PVOID SocketContext; /* Context for lower layer */
@ -290,6 +279,8 @@ typedef struct _CONNECTION_ENDPOINT {
field holds a pointer to this structure */ field holds a pointer to this structure */
typedef struct _CONTROL_CHANNEL { typedef struct _CONTROL_CHANNEL {
LIST_ENTRY ListEntry; /* Entry on list */ LIST_ENTRY ListEntry; /* Entry on list */
LONG RefCount; /* Reference count */
OBJECT_FREE_ROUTINE Free; /* Routine to use to free resources for the object */
KSPIN_LOCK Lock; /* Spin lock to protect this structure */ KSPIN_LOCK Lock; /* Spin lock to protect this structure */
} CONTROL_CHANNEL, *PCONTROL_CHANNEL; } CONTROL_CHANNEL, *PCONTROL_CHANNEL;

View file

@ -89,21 +89,10 @@ VOID DispDataRequestComplete(
* Count = Number of bytes sent or received * Count = Number of bytes sent or received
*/ */
{ {
PIRP Irp; PIRP Irp = Context;
PIO_STACK_LOCATION IrpSp;
KIRQL OldIrql;
TI_DbgPrint(DEBUG_IRP, ("Called for irp %x (%x, %d).\n", TI_DbgPrint(DEBUG_IRP, ("Called for irp %x (%x, %d).\n",
Context, Status, Count)); Irp, Status, Count));
Irp = Context;
IrpSp = IoGetCurrentIrpStackLocation(Irp);
IoAcquireCancelSpinLock(&OldIrql);
(void)IoSetCancelRoutine(Irp, NULL);
IoReleaseCancelSpinLock(OldIrql);
Irp->IoStatus.Status = Status; Irp->IoStatus.Status = Status;
Irp->IoStatus.Information = Count; Irp->IoStatus.Information = Count;
@ -309,18 +298,18 @@ NTSTATUS DispTdiAssociateAddress(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
if (Connection->AddressFile) { if (Connection->AddressFile) {
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(MID_TRACE, ("An address file is already asscociated.\n")); TI_DbgPrint(MID_TRACE, ("An address file is already asscociated.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
if (FileObject->FsContext2 != (PVOID)TDI_TRANSPORT_ADDRESS_FILE) { if (FileObject->FsContext2 != (PVOID)TDI_TRANSPORT_ADDRESS_FILE) {
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(MID_TRACE, ("Bad address file object. Magic (0x%X).\n", TI_DbgPrint(MID_TRACE, ("Bad address file object. Magic (0x%X).\n",
FileObject->FsContext2)); FileObject->FsContext2));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
@ -331,31 +320,33 @@ NTSTATUS DispTdiAssociateAddress(
TranContext = FileObject->FsContext; TranContext = FileObject->FsContext;
if (!TranContext) { if (!TranContext) {
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(MID_TRACE, ("Bad transport context.\n")); TI_DbgPrint(MID_TRACE, ("Bad transport context.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle; AddrFile = (PADDRESS_FILE)TranContext->Handle.AddressHandle;
if (!AddrFile) { if (!AddrFile) {
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
TI_DbgPrint(MID_TRACE, ("No address file object.\n")); TI_DbgPrint(MID_TRACE, ("No address file object.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLockAtDpcLevel(&AddrFile->Lock); LockObjectAtDpcLevel(AddrFile);
ReferenceObject(AddrFile);
Connection->AddressFile = AddrFile; Connection->AddressFile = AddrFile;
/* Add connection endpoint to the address file */ /* Add connection endpoint to the address file */
ReferenceObject(Connection);
AddrFile->Connection = Connection; AddrFile->Connection = Connection;
/* FIXME: Maybe do this in DispTdiDisassociateAddress() instead? */ /* FIXME: Maybe do this in DispTdiDisassociateAddress() instead? */
ObDereferenceObject(FileObject); ObDereferenceObject(FileObject);
KeReleaseSpinLockFromDpcLevel(&AddrFile->Lock); UnlockObjectFromDpcLevel(AddrFile);
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
return Status; return Status;
} }
@ -457,25 +448,27 @@ NTSTATUS DispTdiDisassociateAddress(
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
if (!Connection->AddressFile) { if (!Connection->AddressFile) {
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(MID_TRACE, ("No address file is asscociated.\n")); TI_DbgPrint(MID_TRACE, ("No address file is asscociated.\n"));
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
KeAcquireSpinLockAtDpcLevel(&Connection->AddressFile->Lock); LockObjectAtDpcLevel(Connection->AddressFile);
/* Remove this connection from the address file */ /* Remove this connection from the address file */
DereferenceObject(Connection->AddressFile->Connection);
Connection->AddressFile->Connection = NULL; Connection->AddressFile->Connection = NULL;
KeReleaseSpinLockFromDpcLevel(&Connection->AddressFile->Lock); UnlockObjectFromDpcLevel(Connection->AddressFile);
/* Remove the address file from this connection */ /* Remove the address file from this connection */
DereferenceObject(Connection->AddressFile);
Connection->AddressFile = NULL; Connection->AddressFile = NULL;
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
@ -584,17 +577,17 @@ NTSTATUS DispTdiListen(
Irp, Irp,
(PDRIVER_CANCEL)DispCancelListenRequest); (PDRIVER_CANCEL)DispCancelListenRequest);
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
if (Connection->AddressFile == NULL) if (Connection->AddressFile == NULL)
{ {
TI_DbgPrint(MID_TRACE, ("No associated address file\n")); TI_DbgPrint(MID_TRACE, ("No associated address file\n"));
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
Status = STATUS_INVALID_PARAMETER; Status = STATUS_INVALID_PARAMETER;
goto done; goto done;
} }
KeAcquireSpinLockAtDpcLevel(&Connection->AddressFile->Lock); LockObjectAtDpcLevel(Connection->AddressFile);
/* Listening will require us to create a listening socket and store it in /* Listening will require us to create a listening socket and store it in
* the address file. It will be signalled, and attempt to complete an irp * the address file. It will be signalled, and attempt to complete an irp
@ -609,6 +602,7 @@ NTSTATUS DispTdiListen(
Status = STATUS_NO_MEMORY; Status = STATUS_NO_MEMORY;
if( NT_SUCCESS(Status) ) { if( NT_SUCCESS(Status) ) {
ReferenceObject(Connection->AddressFile);
Connection->AddressFile->Listener->AddressFile = Connection->AddressFile->Listener->AddressFile =
Connection->AddressFile; Connection->AddressFile;
@ -632,8 +626,8 @@ NTSTATUS DispTdiListen(
Irp ); Irp );
} }
KeReleaseSpinLockFromDpcLevel(&Connection->AddressFile->Lock); UnlockObjectFromDpcLevel(Connection->AddressFile);
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
done: done:
if (Status != STATUS_PENDING) { if (Status != STATUS_PENDING) {
@ -1106,7 +1100,7 @@ NTSTATUS DispTdiSetEventHandler(PIRP Irp)
Parameters = (PTDI_REQUEST_KERNEL_SET_EVENT)&IrpSp->Parameters; Parameters = (PTDI_REQUEST_KERNEL_SET_EVENT)&IrpSp->Parameters;
Status = STATUS_SUCCESS; Status = STATUS_SUCCESS;
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
/* Set the event handler. if an event handler is associated with /* Set the event handler. if an event handler is associated with
a specific event, it's flag (RegisteredXxxHandler) is TRUE. a specific event, it's flag (RegisteredXxxHandler) is TRUE.
@ -1227,7 +1221,7 @@ NTSTATUS DispTdiSetEventHandler(PIRP Irp)
Status = STATUS_INVALID_PARAMETER; Status = STATUS_INVALID_PARAMETER;
} }
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return Status; return Status;
} }

View file

@ -153,7 +153,55 @@ VOID AddrFileFree(
* Object = Pointer to address file object to free * Object = Pointer to address file object to free
*/ */
{ {
ExFreePoolWithTag(Object, ADDR_FILE_TAG); PADDRESS_FILE AddrFile = Object;
KIRQL OldIrql;
PDATAGRAM_RECEIVE_REQUEST ReceiveRequest;
PDATAGRAM_SEND_REQUEST SendRequest;
PLIST_ENTRY CurrentEntry;
TI_DbgPrint(MID_TRACE, ("Called.\n"));
/* Remove address file from the global list */
TcpipAcquireSpinLock(&AddressFileListLock, &OldIrql);
RemoveEntryList(&AddrFile->ListEntry);
TcpipReleaseSpinLock(&AddressFileListLock, OldIrql);
/* FIXME: Kill TCP connections on this address file object */
/* Return pending requests with error */
TI_DbgPrint(DEBUG_ADDRFILE, ("Aborting receive requests on AddrFile at (0x%X).\n", AddrFile));
/* Go through pending receive request list and cancel them all */
while ((CurrentEntry = ExInterlockedRemoveHeadList(&AddrFile->ReceiveQueue, &AddrFile->Lock))) {
ReceiveRequest = CONTAINING_RECORD(CurrentEntry, DATAGRAM_RECEIVE_REQUEST, ListEntry);
(*ReceiveRequest->Complete)(ReceiveRequest->Context, STATUS_CANCELLED, 0);
/* ExFreePoolWithTag(ReceiveRequest, DATAGRAM_RECV_TAG); FIXME: WTF? */
}
TI_DbgPrint(DEBUG_ADDRFILE, ("Aborting send requests on address file at (0x%X).\n", AddrFile));
/* Go through pending send request list and cancel them all */
while ((CurrentEntry = ExInterlockedRemoveHeadList(&AddrFile->ReceiveQueue, &AddrFile->Lock))) {
SendRequest = CONTAINING_RECORD(CurrentEntry, DATAGRAM_SEND_REQUEST, ListEntry);
(*SendRequest->Complete)(SendRequest->Context, STATUS_CANCELLED, 0);
ExFreePoolWithTag(SendRequest, DATAGRAM_SEND_TAG);
}
/* Protocol specific handling */
switch (AddrFile->Protocol) {
case IPPROTO_TCP:
TCPFreePort( AddrFile->Port );
break;
case IPPROTO_UDP:
UDPFreePort( AddrFile->Port );
break;
}
RemoveEntityByContext(AddrFile);
ExFreePoolWithTag(Object, ADDR_FILE_TAG);
} }
@ -200,6 +248,7 @@ NTSTATUS FileOpenAddress(
RtlZeroMemory(AddrFile, sizeof(ADDRESS_FILE)); RtlZeroMemory(AddrFile, sizeof(ADDRESS_FILE));
AddrFile->RefCount = 1;
AddrFile->Free = AddrFileFree; AddrFile->Free = AddrFileFree;
/* Set our default TTL */ /* Set our default TTL */
@ -321,64 +370,24 @@ NTSTATUS FileOpenAddress(
NTSTATUS FileCloseAddress( NTSTATUS FileCloseAddress(
PTDI_REQUEST Request) PTDI_REQUEST Request)
{ {
PADDRESS_FILE AddrFile; PADDRESS_FILE AddrFile = Request->Handle.AddressHandle;
NTSTATUS Status = STATUS_SUCCESS;
KIRQL OldIrql; KIRQL OldIrql;
PDATAGRAM_RECEIVE_REQUEST ReceiveRequest;
PDATAGRAM_SEND_REQUEST SendRequest;
PLIST_ENTRY CurrentEntry;
AddrFile = Request->Handle.AddressHandle; if (!Request->Handle.AddressHandle) return STATUS_INVALID_PARAMETER;
TI_DbgPrint(MID_TRACE, ("Called.\n")); LockObject(AddrFile, &OldIrql);
/* We have to close this connection because we started it */
if( AddrFile->Listener )
TCPClose( AddrFile->Listener );
if( AddrFile->Connection )
DereferenceObject( AddrFile->Connection );
UnlockObject(AddrFile, OldIrql);
/* Remove address file from the global list */ DereferenceObject(AddrFile);
TcpipAcquireSpinLock(&AddressFileListLock, &OldIrql);
RemoveEntryList(&AddrFile->ListEntry);
TcpipReleaseSpinLock(&AddressFileListLock, OldIrql);
/* FIXME: Kill TCP connections on this address file object */
/* Return pending requests with error */
TI_DbgPrint(DEBUG_ADDRFILE, ("Aborting receive requests on AddrFile at (0x%X).\n", AddrFile));
/* Go through pending receive request list and cancel them all */
while ((CurrentEntry = ExInterlockedRemoveHeadList(&AddrFile->ReceiveQueue, &AddrFile->Lock))) {
ReceiveRequest = CONTAINING_RECORD(CurrentEntry, DATAGRAM_RECEIVE_REQUEST, ListEntry);
(*ReceiveRequest->Complete)(ReceiveRequest->Context, STATUS_CANCELLED, 0);
/* ExFreePoolWithTag(ReceiveRequest, DATAGRAM_RECV_TAG); FIXME: WTF? */
}
TI_DbgPrint(DEBUG_ADDRFILE, ("Aborting send requests on address file at (0x%X).\n", AddrFile));
/* Go through pending send request list and cancel them all */
while ((CurrentEntry = ExInterlockedRemoveHeadList(&AddrFile->ReceiveQueue, &AddrFile->Lock))) {
SendRequest = CONTAINING_RECORD(CurrentEntry, DATAGRAM_SEND_REQUEST, ListEntry);
(*SendRequest->Complete)(SendRequest->Context, STATUS_CANCELLED, 0);
ExFreePoolWithTag(SendRequest, DATAGRAM_SEND_TAG);
}
/* Protocol specific handling */
switch (AddrFile->Protocol) {
case IPPROTO_TCP:
TCPFreePort( AddrFile->Port );
if( AddrFile->Listener )
TCPClose( AddrFile->Listener );
break;
case IPPROTO_UDP:
UDPFreePort( AddrFile->Port );
break;
}
RemoveEntityByContext(AddrFile);
(*AddrFile->Free)(AddrFile);
TI_DbgPrint(MAX_TRACE, ("Leaving.\n")); TI_DbgPrint(MAX_TRACE, ("Leaving.\n"));
return Status; return STATUS_SUCCESS;
} }
@ -406,7 +415,7 @@ NTSTATUS FileOpenConnection(
Status = TCPSocket( Connection, AF_INET, SOCK_STREAM, IPPROTO_TCP ); Status = TCPSocket( Connection, AF_INET, SOCK_STREAM, IPPROTO_TCP );
if( !NT_SUCCESS(Status) ) { if( !NT_SUCCESS(Status) ) {
TCPFreeConnectionEndpoint( Connection ); DereferenceObject( Connection );
return Status; return Status;
} }
@ -434,14 +443,17 @@ NTSTATUS FileCloseConnection(
Connection = Request->Handle.ConnectionContext; Connection = Request->Handle.ConnectionContext;
if (!Connection) return STATUS_INVALID_PARAMETER;
TCPClose( Connection ); TCPClose( Connection );
Request->Handle.ConnectionContext = NULL;
TI_DbgPrint(MAX_TRACE, ("Leaving.\n")); TI_DbgPrint(MAX_TRACE, ("Leaving.\n"));
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
/* /*
* FUNCTION: Opens a control channel file object * FUNCTION: Opens a control channel file object
* ARGUMENTS: * ARGUMENTS:
@ -475,6 +487,9 @@ NTSTATUS FileOpenControlChannel(
/* Initialize spin lock that protects the address file object */ /* Initialize spin lock that protects the address file object */
KeInitializeSpinLock(&ControlChannel->Lock); KeInitializeSpinLock(&ControlChannel->Lock);
ControlChannel->RefCount = 1;
ControlChannel->Free = ControlChannelFree;
/* Return address file object */ /* Return address file object */
Request->Handle.ControlChannel = ControlChannel; Request->Handle.ControlChannel = ControlChannel;
@ -493,13 +508,13 @@ NTSTATUS FileOpenControlChannel(
NTSTATUS FileCloseControlChannel( NTSTATUS FileCloseControlChannel(
PTDI_REQUEST Request) PTDI_REQUEST Request)
{ {
PCONTROL_CHANNEL ControlChannel = Request->Handle.ControlChannel; if (!Request->Handle.ControlChannel) return STATUS_INVALID_PARAMETER;
NTSTATUS Status = STATUS_SUCCESS;
DereferenceObject((PCONTROL_CHANNEL)Request->Handle.ControlChannel);
ExFreePoolWithTag(ControlChannel, CONTROL_CHANNEL_TAG);
Request->Handle.ControlChannel = NULL; Request->Handle.ControlChannel = NULL;
return Status; return STATUS_SUCCESS;
} }
/* EOF */ /* EOF */

View file

@ -22,7 +22,7 @@ BOOLEAN DGRemoveIRP(
TI_DbgPrint(MAX_TRACE, ("Called (Cancel IRP %08x for file %08x).\n", TI_DbgPrint(MAX_TRACE, ("Called (Cancel IRP %08x for file %08x).\n",
Irp, AddrFile)); Irp, AddrFile));
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
for( ListEntry = AddrFile->ReceiveQueue.Flink; for( ListEntry = AddrFile->ReceiveQueue.Flink;
ListEntry != &AddrFile->ReceiveQueue; ListEntry != &AddrFile->ReceiveQueue;
@ -42,7 +42,7 @@ BOOLEAN DGRemoveIRP(
} }
} }
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
TI_DbgPrint(MAX_TRACE, ("Done.\n")); TI_DbgPrint(MAX_TRACE, ("Done.\n"));
@ -83,7 +83,7 @@ VOID DGDeliverData(
TI_DbgPrint(MAX_TRACE, ("Called.\n")); TI_DbgPrint(MAX_TRACE, ("Called.\n"));
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
if (AddrFile->Protocol == IPPROTO_UDP) if (AddrFile->Protocol == IPPROTO_UDP)
{ {
@ -140,7 +140,8 @@ VOID DGDeliverData(
&SrcAddress->Address.IPv4Address, &SrcAddress->Address.IPv4Address,
sizeof(SrcAddress->Address.IPv4Address) ); sizeof(SrcAddress->Address.IPv4Address) );
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); ReferenceObject(AddrFile);
UnlockObject(AddrFile, OldIrql);
/* Complete the receive request */ /* Complete the receive request */
if (Current->BufferSize < DataSize) if (Current->BufferSize < DataSize)
@ -148,11 +149,12 @@ VOID DGDeliverData(
else else
Current->Complete(Current->Context, STATUS_SUCCESS, DataSize); Current->Complete(Current->Context, STATUS_SUCCESS, DataSize);
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
DereferenceObject(AddrFile);
} }
} }
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
} }
else if (AddrFile->RegisteredReceiveDatagramHandler) else if (AddrFile->RegisteredReceiveDatagramHandler)
{ {
@ -172,7 +174,8 @@ VOID DGDeliverData(
SourceAddress = SrcAddress->Address.IPv6Address; SourceAddress = SrcAddress->Address.IPv6Address;
} }
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); ReferenceObject(AddrFile);
UnlockObject(AddrFile, OldIrql);
Status = (*ReceiveHandler)(HandlerContext, Status = (*ReceiveHandler)(HandlerContext,
AddressLength, AddressLength,
@ -185,10 +188,12 @@ VOID DGDeliverData(
&BytesTaken, &BytesTaken,
DataBuffer, DataBuffer,
NULL); NULL);
DereferenceObject(AddrFile);
} }
else else
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
TI_DbgPrint(MAX_TRACE, ("Discarding datagram.\n")); TI_DbgPrint(MAX_TRACE, ("Discarding datagram.\n"));
} }
@ -238,7 +243,7 @@ NTSTATUS DGReceiveDatagram(
TI_DbgPrint(MAX_TRACE, ("Called.\n")); TI_DbgPrint(MAX_TRACE, ("Called.\n"));
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
ReceiveRequest = ExAllocatePoolWithTag(NonPagedPool, sizeof(DATAGRAM_RECEIVE_REQUEST), ReceiveRequest = ExAllocatePoolWithTag(NonPagedPool, sizeof(DATAGRAM_RECEIVE_REQUEST),
DATAGRAM_RECV_TAG); DATAGRAM_RECV_TAG);
@ -256,7 +261,7 @@ NTSTATUS DGReceiveDatagram(
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
ExFreePoolWithTag(ReceiveRequest, DATAGRAM_RECV_TAG); ExFreePoolWithTag(ReceiveRequest, DATAGRAM_RECV_TAG);
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return Status; return Status;
} }
} }
@ -284,13 +289,13 @@ NTSTATUS DGReceiveDatagram(
TI_DbgPrint(MAX_TRACE, ("Leaving (pending %08x).\n", ReceiveRequest)); TI_DbgPrint(MAX_TRACE, ("Leaving (pending %08x).\n", ReceiveRequest));
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_PENDING; return STATUS_PENDING;
} }
else else
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
Status = STATUS_INSUFFICIENT_RESOURCES; Status = STATUS_INSUFFICIENT_RESOURCES;
} }

View file

@ -197,7 +197,7 @@ NTSTATUS RawIPSendDatagram(
PNEIGHBOR_CACHE_ENTRY NCE; PNEIGHBOR_CACHE_ENTRY NCE;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
TI_DbgPrint(MID_TRACE,("Sending Datagram(%x %x %x %d)\n", TI_DbgPrint(MID_TRACE,("Sending Datagram(%x %x %x %d)\n",
AddrFile, ConnInfo, BufferData, DataSize)); AddrFile, ConnInfo, BufferData, DataSize));
@ -212,7 +212,7 @@ NTSTATUS RawIPSendDatagram(
break; break;
default: default:
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_UNSUCCESSFUL; return STATUS_UNSUCCESSFUL;
} }
@ -226,7 +226,7 @@ NTSTATUS RawIPSendDatagram(
* interface we're sending over * interface we're sending over
*/ */
if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) { if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_NETWORK_UNREACHABLE; return STATUS_NETWORK_UNREACHABLE;
} }
@ -235,7 +235,7 @@ NTSTATUS RawIPSendDatagram(
else else
{ {
if(!(NCE = NBLocateNeighbor( &LocalAddress ))) { if(!(NCE = NBLocateNeighbor( &LocalAddress ))) {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
} }
@ -251,7 +251,7 @@ NTSTATUS RawIPSendDatagram(
if( !NT_SUCCESS(Status) ) if( !NT_SUCCESS(Status) )
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return Status; return Status;
} }
@ -259,14 +259,14 @@ NTSTATUS RawIPSendDatagram(
if (!NT_SUCCESS(Status = IPSendDatagram( &Packet, NCE, RawIpSendPacketComplete, NULL ))) if (!NT_SUCCESS(Status = IPSendDatagram( &Packet, NCE, RawIpSendPacketComplete, NULL )))
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
FreeNdisPacket(Packet.NdisPacket); FreeNdisPacket(Packet.NdisPacket);
return Status; return Status;
} }
TI_DbgPrint(MID_TRACE,("Leaving\n")); TI_DbgPrint(MID_TRACE,("Leaving\n"));
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }

View file

@ -72,10 +72,9 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
KIRQL OldIrql; KIRQL OldIrql;
ASSERT(Connection); ASSERT(Connection);
ASSERT_KM_POINTER(Connection->SocketContext);
ASSERT_KM_POINTER(Connection->AddressFile); ASSERT_KM_POINTER(Connection->AddressFile);
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCPListen started\n")); TI_DbgPrint(DEBUG_TCP,("TCPListen started\n"));
@ -97,7 +96,7 @@ NTSTATUS TCPListen( PCONNECTION_ENDPOINT Connection, UINT Backlog ) {
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
Status = TCPTranslateError( OskitTCPListen( Connection->SocketContext, Backlog ) ); Status = TCPTranslateError( OskitTCPListen( Connection->SocketContext, Backlog ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCPListen finished %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("TCPListen finished %x\n", Status));
@ -111,7 +110,7 @@ BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
KIRQL OldIrql; KIRQL OldIrql;
BOOLEAN Found = FALSE; BOOLEAN Found = FALSE;
KeAcquireSpinLock(&Listener->Lock, &OldIrql); LockObject(Listener, &OldIrql);
ListEntry = Listener->ListenRequest.Flink; ListEntry = Listener->ListenRequest.Flink;
while ( ListEntry != &Listener->ListenRequest ) { while ( ListEntry != &Listener->ListenRequest ) {
@ -127,7 +126,7 @@ BOOLEAN TCPAbortListenForSocket( PCONNECTION_ENDPOINT Listener,
ListEntry = ListEntry->Flink; ListEntry = ListEntry->Flink;
} }
KeReleaseSpinLock(&Listener->Lock, OldIrql); UnlockObject(Listener, OldIrql);
return Found; return Found;
} }
@ -144,27 +143,27 @@ NTSTATUS TCPAccept ( PTDI_REQUEST Request,
TI_DbgPrint(DEBUG_TCP,("TCPAccept started\n")); TI_DbgPrint(DEBUG_TCP,("TCPAccept started\n"));
KeAcquireSpinLock(&Listener->Lock, &OldIrql); LockObject(Listener, &OldIrql);
Status = TCPServiceListeningSocket( Listener, Connection, Status = TCPServiceListeningSocket( Listener, Connection,
(PTDI_REQUEST_KERNEL)Request ); (PTDI_REQUEST_KERNEL)Request );
KeReleaseSpinLock(&Listener->Lock, OldIrql);
if( Status == STATUS_PENDING ) { if( Status == STATUS_PENDING ) {
Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket),
TDI_BUCKET_TAG ); TDI_BUCKET_TAG );
if( Bucket ) { if( Bucket ) {
ReferenceObject(Connection);
Bucket->AssociatedEndpoint = Connection; Bucket->AssociatedEndpoint = Connection;
Bucket->Request.RequestNotifyObject = Complete; Bucket->Request.RequestNotifyObject = Complete;
Bucket->Request.RequestContext = Context; Bucket->Request.RequestContext = Context;
ExInterlockedInsertTailList( &Listener->ListenRequest, &Bucket->Entry, InsertTailList( &Listener->ListenRequest, &Bucket->Entry );
&Listener->Lock );
} else } else
Status = STATUS_NO_MEMORY; Status = STATUS_NO_MEMORY;
} }
UnlockObject(Listener, OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCPAccept finished %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("TCPAccept finished %x\n", Status));
return Status; return Status;
} }

View file

@ -24,26 +24,17 @@ int TCPSocketState(void *ClientData,
NewState & SEL_ACCEPT ? 'A' : 'a', NewState & SEL_ACCEPT ? 'A' : 'a',
NewState & SEL_WRITE ? 'W' : 'w')); NewState & SEL_WRITE ? 'W' : 'w'));
if (!Connection) ASSERT(Connection);
{
return 0;
}
if (ClientInfo.Unlocked)
KeAcquireSpinLockAtDpcLevel(&Connection->Lock);
TI_DbgPrint(DEBUG_TCP,("Called: NewState %x (Conn %x) (Change %x)\n", TI_DbgPrint(DEBUG_TCP,("Called: NewState %x (Conn %x) (Change %x)\n",
NewState, Connection, NewState, Connection,
Connection->SignalState ^ NewState, Connection->SignalState ^ NewState,
NewState)); NewState));
Connection->SignalState |= NewState; Connection->SignalState = NewState;
HandleSignalledConnection(Connection); HandleSignalledConnection(Connection);
if (ClientInfo.Unlocked)
KeReleaseSpinLockFromDpcLevel(&Connection->Lock);
return 0; return 0;
} }
@ -76,8 +67,8 @@ int TCPPacketSend(void *ClientData, OSK_PCHAR data, OSK_UINT len ) {
return OSK_EINVAL; return OSK_EINVAL;
} }
if(!(NCE = NBLocateNeighbor( &LocalAddress ))) { if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) {
TI_DbgPrint(MIN_TRACE,("Interface doesn't exist! %s\n", A2S(&LocalAddress))); TI_DbgPrint(MIN_TRACE,("Unable to get route to %s\n", A2S(&RemoteAddress)));
return OSK_EADDRNOTAVAIL; return OSK_EADDRNOTAVAIL;
} }

View file

@ -18,28 +18,6 @@ static NPAGED_LOOKASIDE_LIST TCPSegmentList;
PORT_SET TCPPorts; PORT_SET TCPPorts;
CLIENT_DATA ClientInfo; CLIENT_DATA ClientInfo;
static VOID
ProcessCompletions(PCONNECTION_ENDPOINT Connection)
{
PLIST_ENTRY CurrentEntry;
PTDI_BUCKET Bucket;
PTCP_COMPLETION_ROUTINE Complete;
while ((CurrentEntry = ExInterlockedRemoveHeadList(&Connection->CompletionQueue,
&Connection->Lock)))
{
Bucket = CONTAINING_RECORD(CurrentEntry, TDI_BUCKET, Entry);
Complete = Bucket->Request.RequestNotifyObject;
Complete(Bucket->Request.RequestContext, Bucket->Status, Bucket->Information);
ExFreePoolWithTag(Bucket, TDI_BUCKET_TAG);
}
if (!Connection->SocketContext)
TCPFreeConnectionEndpoint(Connection);
}
VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection) VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
{ {
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
@ -48,11 +26,16 @@ VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
PIRP Irp; PIRP Irp;
PMDL Mdl; PMDL Mdl;
ULONG SocketError = 0; ULONG SocketError = 0;
KIRQL OldIrql;
PTCP_COMPLETION_ROUTINE Complete;
if (ClientInfo.Unlocked)
LockObjectAtDpcLevel(Connection);
TI_DbgPrint(MID_TRACE,("Handling signalled state on %x (%x)\n", TI_DbgPrint(MID_TRACE,("Handling signalled state on %x (%x)\n",
Connection, Connection->SocketContext)); Connection, Connection->SocketContext));
if( !Connection->SocketContext || Connection->SignalState & SEL_FIN ) { if( Connection->SignalState & SEL_FIN ) {
TI_DbgPrint(DEBUG_TCP, ("EOF From socket\n")); TI_DbgPrint(DEBUG_TCP, ("EOF From socket\n"));
/* If OskitTCP initiated the disconnect, try to read the socket error that occurred */ /* If OskitTCP initiated the disconnect, try to read the socket error that occurred */
@ -95,6 +78,7 @@ VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
Bucket->Status = SocketError; Bucket->Status = SocketError;
Bucket->Information = 0; Bucket->Information = 0;
DereferenceObject(Bucket->AssociatedEndpoint);
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
@ -111,7 +95,7 @@ VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
Connection->SignalState = 0; Connection->SignalState = SEL_FIN;
} }
/* Things that can happen when we try the initial connection */ /* Things that can happen when we try the initial connection */
@ -162,6 +146,7 @@ VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
} else { } else {
Bucket->Status = Status; Bucket->Status = Status;
Bucket->Information = 0; Bucket->Information = 0;
DereferenceObject(Bucket->AssociatedEndpoint);
InsertTailList(&Connection->CompletionQueue, &Bucket->Entry); InsertTailList(&Connection->CompletionQueue, &Bucket->Entry);
} }
@ -278,39 +263,55 @@ VOID HandleSignalledConnection(PCONNECTION_ENDPOINT Connection)
} }
} }
} }
ReferenceObject(Connection);
if (ClientInfo.Unlocked)
{
UnlockObjectFromDpcLevel(Connection);
KeReleaseSpinLock(&ClientInfo.Lock, ClientInfo.OldIrql);
}
else
{
UnlockObject(Connection, Connection->OldIrql);
}
while ((Entry = ExInterlockedRemoveHeadList(&Connection->CompletionQueue,
&Connection->Lock)))
{
Bucket = CONTAINING_RECORD(Entry, TDI_BUCKET, Entry);
Complete = Bucket->Request.RequestNotifyObject;
Complete(Bucket->Request.RequestContext, Bucket->Status, Bucket->Information);
ExFreePoolWithTag(Bucket, TDI_BUCKET_TAG);
}
if (!ClientInfo.Unlocked)
{
LockObject(Connection, &OldIrql);
}
else
{
KeAcquireSpinLock(&ClientInfo.Lock, &ClientInfo.OldIrql);
}
DereferenceObject(Connection);
/* If the socket is dead, remove the reference we added for oskit */
if (Connection->SignalState & SEL_FIN)
DereferenceObject(Connection);
} }
static VOID ConnectionFree(PVOID Object) {
VOID DrainSignals(VOID) { PCONNECTION_ENDPOINT Connection = Object;
PCONNECTION_ENDPOINT Connection;
PLIST_ENTRY CurrentEntry;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql); TI_DbgPrint(DEBUG_TCP, ("Freeing TCP Endpoint\n"));
CurrentEntry = ConnectionEndpointListHead.Flink;
while (CurrentEntry != &ConnectionEndpointListHead)
{
Connection = CONTAINING_RECORD( CurrentEntry, CONNECTION_ENDPOINT,
ListEntry );
CurrentEntry = CurrentEntry->Flink;
KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
KeAcquireSpinLock(&Connection->Lock, &OldIrql); TcpipAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
if (Connection->SocketContext) RemoveEntryList(&Connection->ListEntry);
{ TcpipReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
HandleSignalledConnection(Connection);
KeReleaseSpinLock(&Connection->Lock, OldIrql);
ProcessCompletions(Connection); ExFreePoolWithTag( Connection, CONN_ENDPT_TAG );
}
else
{
KeReleaseSpinLock(&Connection->Lock, OldIrql);
}
KeAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
}
KeReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
} }
PCONNECTION_ENDPOINT TCPAllocateConnectionEndpoint( PVOID ClientContext ) { PCONNECTION_ENDPOINT TCPAllocateConnectionEndpoint( PVOID ClientContext ) {
@ -335,6 +336,10 @@ PCONNECTION_ENDPOINT TCPAllocateConnectionEndpoint( PVOID ClientContext ) {
/* Save client context pointer */ /* Save client context pointer */
Connection->ClientContext = ClientContext; Connection->ClientContext = ClientContext;
/* Add an extra reference for oskit */
Connection->RefCount = 2;
Connection->Free = ConnectionFree;
/* Add connection endpoint to global list */ /* Add connection endpoint to global list */
ExInterlockedInsertTailList(&ConnectionEndpointListHead, ExInterlockedInsertTailList(&ConnectionEndpointListHead,
&Connection->ListEntry, &Connection->ListEntry,
@ -343,24 +348,12 @@ PCONNECTION_ENDPOINT TCPAllocateConnectionEndpoint( PVOID ClientContext ) {
return Connection; return Connection;
} }
VOID TCPFreeConnectionEndpoint( PCONNECTION_ENDPOINT Connection ) {
KIRQL OldIrql;
TI_DbgPrint(DEBUG_TCP, ("Freeing TCP Endpoint\n"));
TcpipAcquireSpinLock(&ConnectionEndpointListLock, &OldIrql);
RemoveEntryList(&Connection->ListEntry);
TcpipReleaseSpinLock(&ConnectionEndpointListLock, OldIrql);
ExFreePoolWithTag( Connection, CONN_ENDPT_TAG );
}
NTSTATUS TCPSocket( PCONNECTION_ENDPOINT Connection, NTSTATUS TCPSocket( PCONNECTION_ENDPOINT Connection,
UINT Family, UINT Type, UINT Proto ) { UINT Family, UINT Type, UINT Proto ) {
NTSTATUS Status; NTSTATUS Status;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
TI_DbgPrint(DEBUG_TCP,("Called: Connection %x, Family %d, Type %d, " TI_DbgPrint(DEBUG_TCP,("Called: Connection %x, Family %d, Type %d, "
"Proto %d\n", "Proto %d\n",
@ -377,7 +370,7 @@ NTSTATUS TCPSocket( PCONNECTION_ENDPOINT Connection,
TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext %x\n", TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext %x\n",
Connection->SocketContext)); Connection->SocketContext));
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
return Status; return Status;
} }
@ -399,6 +392,7 @@ VOID TCPReceive(PIP_INTERFACE Interface, PIP_PACKET IPPacket)
KeAcquireSpinLock(&ClientInfo.Lock, &OldIrql); KeAcquireSpinLock(&ClientInfo.Lock, &OldIrql);
ClientInfo.Unlocked = TRUE; ClientInfo.Unlocked = TRUE;
ClientInfo.OldIrql = OldIrql;
OskitTCPReceiveDatagram( IPPacket->Header, OskitTCPReceiveDatagram( IPPacket->Header,
IPPacket->TotalSize, IPPacket->TotalSize,
@ -477,7 +471,6 @@ TimerThread(PVOID Context)
} }
TimerOskitTCP( Next == NextFast, Next == NextSlow ); TimerOskitTCP( Next == NextFast, Next == NextSlow );
DrainSignals();
Current = Next; Current = Next;
if (10 <= Current) { if (10 <= Current) {
@ -640,11 +633,11 @@ NTSTATUS TCPConnect
AddressToConnect.sin_family = AF_INET; AddressToConnect.sin_family = AF_INET;
AddressToBind = AddressToConnect; AddressToBind = AddressToConnect;
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
if (!Connection->AddressFile) if (!Connection->AddressFile)
{ {
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
@ -652,7 +645,7 @@ NTSTATUS TCPConnect
{ {
if (!(NCE = RouteGetRouteToDestination(&RemoteAddress))) if (!(NCE = RouteGetRouteToDestination(&RemoteAddress)))
{ {
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
return STATUS_NETWORK_UNREACHABLE; return STATUS_NETWORK_UNREACHABLE;
} }
@ -679,26 +672,24 @@ NTSTATUS TCPConnect
&AddressToConnect, &AddressToConnect,
sizeof(AddressToConnect) ) ); sizeof(AddressToConnect) ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
if (Status == STATUS_PENDING) if (Status == STATUS_PENDING)
{ {
Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG ); Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG );
if( !Bucket ) if( !Bucket )
{ {
UnlockObject(Connection, OldIrql);
return STATUS_NO_MEMORY; return STATUS_NO_MEMORY;
} }
Bucket->Request.RequestNotifyObject = (PVOID)Complete; Bucket->Request.RequestNotifyObject = (PVOID)Complete;
Bucket->Request.RequestContext = Context; Bucket->Request.RequestContext = Context;
ExInterlockedInsertTailList( &Connection->ConnectRequest, &Bucket->Entry, InsertTailList( &Connection->ConnectRequest, &Bucket->Entry );
&Connection->Lock );
} }
} else {
KeReleaseSpinLock(&Connection->Lock, OldIrql);
} }
UnlockObject(Connection, OldIrql);
return Status; return Status;
} }
@ -714,7 +705,7 @@ NTSTATUS TCPDisconnect
TI_DbgPrint(DEBUG_TCP,("started\n")); TI_DbgPrint(DEBUG_TCP,("started\n"));
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
if (Flags & TDI_DISCONNECT_RELEASE) if (Flags & TDI_DISCONNECT_RELEASE)
Status = TCPTranslateError(OskitTCPDisconnect(Connection->SocketContext)); Status = TCPTranslateError(OskitTCPDisconnect(Connection->SocketContext));
@ -722,7 +713,7 @@ NTSTATUS TCPDisconnect
if ((Flags & TDI_DISCONNECT_ABORT) || !Flags) if ((Flags & TDI_DISCONNECT_ABORT) || !Flags)
Status = TCPTranslateError(OskitTCPShutdown(Connection->SocketContext, FWRITE | FREAD)); Status = TCPTranslateError(OskitTCPShutdown(Connection->SocketContext, FWRITE | FREAD));
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
TI_DbgPrint(DEBUG_TCP,("finished %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("finished %x\n", Status));
@ -730,24 +721,36 @@ NTSTATUS TCPDisconnect
} }
NTSTATUS TCPClose NTSTATUS TCPClose
( PCONNECTION_ENDPOINT Connection ) { ( PCONNECTION_ENDPOINT Connection )
NTSTATUS Status; {
KIRQL OldIrql; KIRQL OldIrql;
NTSTATUS Status;
PVOID Socket; PVOID Socket;
TI_DbgPrint(DEBUG_TCP,("TCPClose started\n")); /* We don't rely on SocketContext == NULL for socket
* closure anymore but we still need it to determine
KeAcquireSpinLock(&Connection->Lock, &OldIrql); * if we caused the closure
*/
Socket = Connection->SocketContext; Socket = Connection->SocketContext;
Connection->SocketContext = NULL; Connection->SocketContext = NULL;
/* We need to close here otherwise oskit will never indicate
* SEL_FIN and we will never fully close the connection
*/
LockObject(Connection, &OldIrql);
Status = TCPTranslateError( OskitTCPClose( Socket ) ); Status = TCPTranslateError( OskitTCPClose( Socket ) );
UnlockObject(Connection, OldIrql);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
Connection->SocketContext = Socket; Connection->SocketContext = Socket;
return Status;
} }
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("TCPClose finished %x\n", Status)); if (Connection->AddressFile)
DereferenceObject(Connection->AddressFile);
DereferenceObject(Connection);
return Status; return Status;
} }
@ -773,9 +776,7 @@ NTSTATUS TCPReceiveData
TI_DbgPrint(DEBUG_TCP,("TCP>|< Got an MDL %x (%x:%d)\n", Buffer, DataBuffer, DataLen)); TI_DbgPrint(DEBUG_TCP,("TCP>|< Got an MDL %x (%x:%d)\n", Buffer, DataBuffer, DataLen));
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
ASSERT_KM_POINTER(Connection->SocketContext);
Status = TCPTranslateError Status = TCPTranslateError
( OskitTCPRecv ( OskitTCPRecv
@ -785,8 +786,6 @@ NTSTATUS TCPReceiveData
&Received, &Received,
ReceiveFlags ) ); ReceiveFlags ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("OskitTCPReceive: %x, %d\n", Status, Received)); TI_DbgPrint(DEBUG_TCP,("OskitTCPReceive: %x, %d\n", Status, Received));
/* Keep this request around ... there was no data yet */ /* Keep this request around ... there was no data yet */
@ -795,6 +794,7 @@ NTSTATUS TCPReceiveData
Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG ); Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG );
if( !Bucket ) { if( !Bucket ) {
TI_DbgPrint(DEBUG_TCP,("Failed to allocate bucket\n")); TI_DbgPrint(DEBUG_TCP,("Failed to allocate bucket\n"));
UnlockObject(Connection, OldIrql);
return STATUS_NO_MEMORY; return STATUS_NO_MEMORY;
} }
@ -802,14 +802,15 @@ NTSTATUS TCPReceiveData
Bucket->Request.RequestContext = Context; Bucket->Request.RequestContext = Context;
*BytesReceived = 0; *BytesReceived = 0;
ExInterlockedInsertTailList( &Connection->ReceiveRequest, &Bucket->Entry, InsertTailList( &Connection->ReceiveRequest, &Bucket->Entry );
&Connection->Lock );
TI_DbgPrint(DEBUG_TCP,("Queued read irp\n")); TI_DbgPrint(DEBUG_TCP,("Queued read irp\n"));
} else { } else {
TI_DbgPrint(DEBUG_TCP,("Got status %x, bytes %d\n", Status, Received)); TI_DbgPrint(DEBUG_TCP,("Got status %x, bytes %d\n", Status, Received));
*BytesReceived = Received; *BytesReceived = Received;
} }
UnlockObject(Connection, OldIrql);
TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
return Status; return Status;
@ -828,13 +829,11 @@ NTSTATUS TCPSendData
PTDI_BUCKET Bucket; PTDI_BUCKET Bucket;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
TI_DbgPrint(DEBUG_TCP,("Called for %d bytes (on socket %x)\n", TI_DbgPrint(DEBUG_TCP,("Called for %d bytes (on socket %x)\n",
SendLength, Connection->SocketContext)); SendLength, Connection->SocketContext));
ASSERT_KM_POINTER(Connection->SocketContext);
TI_DbgPrint(DEBUG_TCP,("Connection = %x\n", Connection)); TI_DbgPrint(DEBUG_TCP,("Connection = %x\n", Connection));
TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext = %x\n", TI_DbgPrint(DEBUG_TCP,("Connection->SocketContext = %x\n",
Connection->SocketContext)); Connection->SocketContext));
@ -844,8 +843,6 @@ NTSTATUS TCPSendData
(OSK_PCHAR)BufferData, SendLength, (OSK_PCHAR)BufferData, SendLength,
&Sent, 0 ) ); &Sent, 0 ) );
KeReleaseSpinLock(&Connection->Lock, OldIrql);
TI_DbgPrint(DEBUG_TCP,("OskitTCPSend: %x, %d\n", Status, Sent)); TI_DbgPrint(DEBUG_TCP,("OskitTCPSend: %x, %d\n", Status, Sent));
/* Keep this request around ... there was no data yet */ /* Keep this request around ... there was no data yet */
@ -853,6 +850,7 @@ NTSTATUS TCPSendData
/* Freed in TCPSocketState */ /* Freed in TCPSocketState */
Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG ); Bucket = ExAllocatePoolWithTag( NonPagedPool, sizeof(*Bucket), TDI_BUCKET_TAG );
if( !Bucket ) { if( !Bucket ) {
UnlockObject(Connection, OldIrql);
TI_DbgPrint(DEBUG_TCP,("Failed to allocate bucket\n")); TI_DbgPrint(DEBUG_TCP,("Failed to allocate bucket\n"));
return STATUS_NO_MEMORY; return STATUS_NO_MEMORY;
} }
@ -861,14 +859,15 @@ NTSTATUS TCPSendData
Bucket->Request.RequestContext = Context; Bucket->Request.RequestContext = Context;
*BytesSent = 0; *BytesSent = 0;
ExInterlockedInsertTailList( &Connection->SendRequest, &Bucket->Entry, InsertTailList( &Connection->SendRequest, &Bucket->Entry );
&Connection->Lock );
TI_DbgPrint(DEBUG_TCP,("Queued write irp\n")); TI_DbgPrint(DEBUG_TCP,("Queued write irp\n"));
} else { } else {
TI_DbgPrint(DEBUG_TCP,("Got status %x, bytes %d\n", Status, Sent)); TI_DbgPrint(DEBUG_TCP,("Got status %x, bytes %d\n", Status, Sent));
*BytesSent = Sent; *BytesSent = Sent;
} }
UnlockObject(Connection, OldIrql);
TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status)); TI_DbgPrint(DEBUG_TCP,("Status %x\n", Status));
return Status; return Status;
@ -899,13 +898,13 @@ NTSTATUS TCPGetSockAddress
NTSTATUS Status; NTSTATUS Status;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&Connection->Lock, &OldIrql); LockObject(Connection, &OldIrql);
Status = TCPTranslateError(OskitTCPGetAddress(Connection->SocketContext, Status = TCPTranslateError(OskitTCPGetAddress(Connection->SocketContext,
&LocalAddress, &LocalPort, &LocalAddress, &LocalPort,
&RemoteAddress, &RemotePort)); &RemoteAddress, &RemotePort));
KeReleaseSpinLock(&Connection->Lock, OldIrql); UnlockObject(Connection, OldIrql);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
return Status; return Status;
@ -932,7 +931,7 @@ BOOLEAN TCPRemoveIRP( PCONNECTION_ENDPOINT Endpoint, PIRP Irp ) {
ListHead[2] = &Endpoint->ConnectRequest; ListHead[2] = &Endpoint->ConnectRequest;
ListHead[3] = &Endpoint->ListenRequest; ListHead[3] = &Endpoint->ListenRequest;
TcpipAcquireSpinLock( &Endpoint->Lock, &OldIrql ); LockObject(Endpoint, &OldIrql);
for( i = 0; i < 4; i++ ) for( i = 0; i < 4; i++ )
{ {
@ -951,7 +950,7 @@ BOOLEAN TCPRemoveIRP( PCONNECTION_ENDPOINT Endpoint, PIRP Irp ) {
} }
} }
TcpipReleaseSpinLock( &Endpoint->Lock, OldIrql ); UnlockObject(Endpoint, OldIrql);
return Found; return Found;
} }

View file

@ -174,7 +174,7 @@ NTSTATUS UDPSendDatagram(
PNEIGHBOR_CACHE_ENTRY NCE; PNEIGHBOR_CACHE_ENTRY NCE;
KIRQL OldIrql; KIRQL OldIrql;
KeAcquireSpinLock(&AddrFile->Lock, &OldIrql); LockObject(AddrFile, &OldIrql);
TI_DbgPrint(MID_TRACE,("Sending Datagram(%x %x %x %d)\n", TI_DbgPrint(MID_TRACE,("Sending Datagram(%x %x %x %d)\n",
AddrFile, ConnInfo, BufferData, DataSize)); AddrFile, ConnInfo, BufferData, DataSize));
@ -189,7 +189,7 @@ NTSTATUS UDPSendDatagram(
break; break;
default: default:
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_UNSUCCESSFUL; return STATUS_UNSUCCESSFUL;
} }
@ -201,7 +201,7 @@ NTSTATUS UDPSendDatagram(
* interface we're sending over * interface we're sending over
*/ */
if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) { if(!(NCE = RouteGetRouteToDestination( &RemoteAddress ))) {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_NETWORK_UNREACHABLE; return STATUS_NETWORK_UNREACHABLE;
} }
@ -210,7 +210,7 @@ NTSTATUS UDPSendDatagram(
else else
{ {
if(!(NCE = NBLocateNeighbor( &LocalAddress ))) { if(!(NCE = NBLocateNeighbor( &LocalAddress ))) {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
} }
@ -226,18 +226,18 @@ NTSTATUS UDPSendDatagram(
if( !NT_SUCCESS(Status) ) if( !NT_SUCCESS(Status) )
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return Status; return Status;
} }
if (!NT_SUCCESS(Status = IPSendDatagram( &Packet, NCE, UDPSendPacketComplete, NULL ))) if (!NT_SUCCESS(Status = IPSendDatagram( &Packet, NCE, UDPSendPacketComplete, NULL )))
{ {
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
FreeNdisPacket(Packet.NdisPacket); FreeNdisPacket(Packet.NdisPacket);
return Status; return Status;
} }
KeReleaseSpinLock(&AddrFile->Lock, OldIrql); UnlockObject(AddrFile, OldIrql);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }

View file

@ -285,16 +285,11 @@ int OskitTCPShutdown( void *socket, int disconn_type ) {
int OskitTCPClose( void *socket ) { int OskitTCPClose( void *socket ) {
int error; int error;
struct socket *so = socket;
if (!socket) if (!socket)
return OSK_ESHUTDOWN; return OSK_ESHUTDOWN;
OSKLock(); OSKLock();
/* We have to remove the socket context here otherwise we end up
* back in HandleSignalledConnection with a freed connection context
*/
so->so_connection = NULL;
error = soclose( socket ); error = soclose( socket );
OSKUnlock(); OSKUnlock();

View file

@ -39,8 +39,8 @@ void wakeup( struct socket *so, void *token ) {
OS_DbgPrint(OSK_MID_TRACE,("Socket writeable\n")); OS_DbgPrint(OSK_MID_TRACE,("Socket writeable\n"));
flags |= SEL_WRITE; flags |= SEL_WRITE;
} }
if( so->so_state & SS_CANTRCVMORE ) { if (!so->so_pcb) {
OS_DbgPrint(OSK_MID_TRACE,("Socket can't be read any longer\n")); OS_DbgPrint(OSK_MID_TRACE,("Socket dying\n"));
flags |= SEL_FIN; flags |= SEL_FIN;
} }