- Add locking to protect the socket list

svn path=/trunk/; revision=47803
This commit is contained in:
Cameron Gutman 2010-06-19 05:04:40 +00:00
parent a8c547f091
commit 135340e065

View file

@ -23,6 +23,7 @@ HANDLE GlobalHeap;
WSPUPCALLTABLE Upcalls; WSPUPCALLTABLE Upcalls;
LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest; LPWPUCOMPLETEOVERLAPPEDREQUEST lpWPUCompleteOverlappedRequest;
PSOCKET_INFORMATION SocketListHead = NULL; PSOCKET_INFORMATION SocketListHead = NULL;
CRITICAL_SECTION SocketListLock;
LIST_ENTRY SockHelpersListHead = { NULL, NULL }; LIST_ENTRY SockHelpersListHead = { NULL, NULL };
ULONG SockAsyncThreadRefCount; ULONG SockAsyncThreadRefCount;
HANDLE SockAsyncHelperAfdHandle; HANDLE SockAsyncHelperAfdHandle;
@ -280,8 +281,10 @@ WSPSocket(int AddressFamily,
NULL); NULL);
/* Save in Process Sockets List */ /* Save in Process Sockets List */
EnterCriticalSection(&SocketListLock);
Socket->NextSocket = SocketListHead; Socket->NextSocket = SocketListHead;
SocketListHead = Socket; SocketListHead = Socket;
LeaveCriticalSection(&SocketListLock);
/* Create the Socket Context */ /* Create the Socket Context */
CreateContext(Socket); CreateContext(Socket);
@ -556,6 +559,7 @@ WSPCloseSocket(IN SOCKET Handle,
NtClose(Socket->TdiConnectionHandle); NtClose(Socket->TdiConnectionHandle);
Socket->TdiConnectionHandle = NULL; Socket->TdiConnectionHandle = NULL;
EnterCriticalSection(&SocketListLock);
if (SocketListHead == Socket) if (SocketListHead == Socket)
{ {
SocketListHead = SocketListHead->NextSocket; SocketListHead = SocketListHead->NextSocket;
@ -574,6 +578,7 @@ WSPCloseSocket(IN SOCKET Handle,
CurrentSocket = CurrentSocket->NextSocket; CurrentSocket = CurrentSocket->NextSocket;
} }
} }
LeaveCriticalSection(&SocketListLock);
HeapFree(GlobalHeap, 0, Socket); HeapFree(GlobalHeap, 0, Socket);
@ -2314,15 +2319,22 @@ GetSocketStructure(SOCKET Handle)
{ {
PSOCKET_INFORMATION CurrentSocket; PSOCKET_INFORMATION CurrentSocket;
EnterCriticalSection(&SocketListLock);
CurrentSocket = SocketListHead; CurrentSocket = SocketListHead;
while (CurrentSocket) while (CurrentSocket)
{ {
if (CurrentSocket->Handle == Handle) if (CurrentSocket->Handle == Handle)
{
LeaveCriticalSection(&SocketListLock);
return CurrentSocket; return CurrentSocket;
}
CurrentSocket = CurrentSocket->NextSocket; CurrentSocket = CurrentSocket->NextSocket;
} }
LeaveCriticalSection(&SocketListLock);
return NULL; return NULL;
} }
@ -2841,6 +2853,9 @@ DllMain(HANDLE hInstDll,
/* Heap to use when allocating */ /* Heap to use when allocating */
GlobalHeap = GetProcessHeap(); GlobalHeap = GetProcessHeap();
/* Initialize the lock that protects our socket list */
InitializeCriticalSection(&SocketListLock);
AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n")); AFD_DbgPrint(MAX_TRACE, ("MSAFD.DLL has been loaded\n"));
break; break;
@ -2852,6 +2867,10 @@ DllMain(HANDLE hInstDll,
break; break;
case DLL_PROCESS_DETACH: case DLL_PROCESS_DETACH:
/* Delete the socket list lock */
DeleteCriticalSection(&SocketListLock);
break; break;
} }