AcceptSockets hold a reference to their ListenSockets.

So accepted connections stay up when the listen socket
is closed.
This commit is contained in:
Johannes Thoma 2025-03-26 15:43:16 +00:00
parent 21aa24c6b8
commit 871eb35068

View file

@ -51,8 +51,8 @@ ULONG DebugTraceLevel = MIN_TRACE;
#define AFD_SHARE_EXCLUSIVE 0x3L #define AFD_SHARE_EXCLUSIVE 0x3L
/* Function trace */ /* Function trace */
#define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__) // #define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__)
// #define FUNCTION_TRACE do { } while (0) #define FUNCTION_TRACE do { } while (0)
#if 0 #if 0
@ -80,8 +80,6 @@ ULONG DebugTraceLevel = MIN_TRACE;
#endif #endif
struct ListenContext;
typedef struct _WSK_SOCKET_INTERNAL typedef struct _WSK_SOCKET_INTERNAL
{ {
WSK_SOCKET s; WSK_SOCKET s;
@ -119,7 +117,12 @@ typedef struct _WSK_SOCKET_INTERNAL
PKTHREAD ListenThread; PKTHREAD ListenThread;
KEVENT StartListenEvent; KEVENT StartListenEvent;
BOOLEAN ListenThreadShouldRun; BOOLEAN ListenThreadShouldRun;
struct ListenContext *l;
/* AcceptSocket's keep a reference on their listen sockets so
* that the Address file will be closed only if there are no
* open connections any more.
*/
struct _WSK_SOCKET_INTERNAL *ListenSocket;
} WSK_SOCKET_INTERNAL, *PWSK_SOCKET_INTERNAL; } WSK_SOCKET_INTERNAL, *PWSK_SOCKET_INTERNAL;
struct NetioContext struct NetioContext
@ -139,6 +142,9 @@ SocketGet(PWSK_SOCKET_INTERNAL s)
// DbgPrint("SocketGet: refcount is %d socket is %p\n", s->RefCount, s); // DbgPrint("SocketGet: refcount is %d socket is %p\n", s->RefCount, s);
} }
void
SocketPut(PWSK_SOCKET_INTERNAL s);
static static
void SocketShutdown(PWSK_SOCKET_INTERNAL s) void SocketShutdown(PWSK_SOCKET_INTERNAL s)
{ {
@ -146,6 +152,11 @@ void SocketShutdown(PWSK_SOCKET_INTERNAL s)
FUNCTION_TRACE; FUNCTION_TRACE;
if (s->ListenSocket != NULL)
{
SocketPut(s->ListenSocket);
s->ListenSocket = NULL;
}
if (s->ListenThreadHandle != NULL) if (s->ListenThreadHandle != NULL)
{ {
s->ListenThreadShouldRun = FALSE; s->ListenThreadShouldRun = FALSE;
@ -194,11 +205,14 @@ SocketPut(PWSK_SOCKET_INTERNAL s)
{ {
SocketShutdown(s); /* noop when called twice */ SocketShutdown(s); /* noop when called twice */
if (s->ConnectionHandle != NULL) /* Especially listen sockets must keep the LocalAddressFile
{ * open when being closed and there are still accepted sockets
ZwClose(s->ConnectionHandle); * somewhere. Else the accepted socket I/O will fail with
s->ConnectionHandle = NULL; * c0000184 (invalid device state). So do not close the
} * address file in Shutdown only close it here when all AcceptSockets
* are gone (AcceptSockets hold a reference to the listen socket).
*/
if (s->LocalAddressFile != NULL) { if (s->LocalAddressFile != NULL) {
ObDereferenceObject(s->LocalAddressFile); ObDereferenceObject(s->LocalAddressFile);
s->LocalAddressFile = NULL; s->LocalAddressFile = NULL;
@ -329,10 +343,8 @@ static NTSTATUS CreateSocket(
IoSetCompletionRoutine(NewSocketIrp, CompletionFireEvent, &CompletionEvent, TRUE, TRUE, TRUE); IoSetCompletionRoutine(NewSocketIrp, CompletionFireEvent, &CompletionEvent, TRUE, TRUE, TRUE);
NewSocketIrp->Tail.Overlay.Thread = PsGetCurrentThread(); NewSocketIrp->Tail.Overlay.Thread = PsGetCurrentThread();
DbgPrint("into WskSocket ...\n");
Status = WskSocket(NULL, AddressFamily, SocketType, Protocol, Flags, Status = WskSocket(NULL, AddressFamily, SocketType, Protocol, Flags,
NULL, NULL, NULL, NULL, NULL, NewSocketIrp); NULL, NULL, NULL, NULL, NULL, NewSocketIrp);
DbgPrint("out of WskSocket ...\n");
if (Status == STATUS_PENDING) { if (Status == STATUS_PENDING) {
KeWaitForSingleObject(&CompletionEvent, Executive, KernelMode, FALSE, NULL); KeWaitForSingleObject(&CompletionEvent, Executive, KernelMode, FALSE, NULL);
@ -390,8 +402,6 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context)
memcpy(&AcceptSocket->RemoteAddress, RemoteAddress, sizeof(AcceptSocket->RemoteAddress)); memcpy(&AcceptSocket->RemoteAddress, RemoteAddress, sizeof(AcceptSocket->RemoteAddress));
} }
/* HACK!! We have a backpointer for use of TdiAccept() in RequeueListenThread() do not free l */
ListenSocket->l = l;
/* And wait for the next incoming connection. */ /* And wait for the next incoming connection. */
/* This is done in a separate thread at IRQL = 0 */ /* This is done in a separate thread at IRQL = 0 */
QueueListening(ListenSocket); QueueListening(ListenSocket);
@ -401,7 +411,7 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context)
SocketPut(AcceptSocket); SocketPut(AcceptSocket);
SocketPut(ListenSocket); SocketPut(ListenSocket);
// TODO: free ReturnConnectionInfo, RequestConnectionInfo // TODO: free ReturnConnectionInfo, RequestConnectionInfo
// ExFreePoolWithTag(l, TAG_NETIO); ExFreePoolWithTag(l, TAG_NETIO);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
@ -413,7 +423,6 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket)
NTSTATUS status; NTSTATUS status;
struct ListenContext *lc; struct ListenContext *lc;
PWSK_SOCKET_INTERNAL AcceptSocket; PWSK_SOCKET_INTERNAL AcceptSocket;
// PIRP AcceptIrp;
FUNCTION_TRACE; FUNCTION_TRACE;
@ -423,14 +432,15 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket)
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
DbgPrint("into CreateSocket ...\n");
status = CreateSocket(ListenSocket->family, ListenSocket->type, ListenSocket->proto, WSK_FLAG_CONNECTION_SOCKET, &AcceptSocket); status = CreateSocket(ListenSocket->family, ListenSocket->type, ListenSocket->proto, WSK_FLAG_CONNECTION_SOCKET, &AcceptSocket);
DbgPrint("out of CreateSocket ...\n");
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
{ {
DbgPrint("Could not create AcceptSocket, status is 0x%08x\n", status); DbgPrint("Could not create AcceptSocket, status is 0x%08x\n", status);
return status; return status;
} }
AcceptSocket->ListenSocket = ListenSocket;
/* Put when the AcceptSocket is closed */
SocketGet(AcceptSocket->ListenSocket);
status = STATUS_INSUFFICIENT_RESOURCES; status = STATUS_INSUFFICIENT_RESOURCES;
lc = ExAllocatePoolWithTag(NonPagedPool, sizeof(*lc), TAG_NETIO); lc = ExAllocatePoolWithTag(NonPagedPool, sizeof(*lc), TAG_NETIO);
@ -479,15 +489,6 @@ DbgPrint("out of CreateSocket ...\n");
} }
ListenSocket->ListenIrp = tdiIrp; ListenSocket->ListenIrp = tdiIrp;
#if 0
AcceptIrp = NULL;
status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, lc->RequestConnectionInfo, lc->ReturnConnectionInfo, AcceptComplete, AcceptSocket);
if (!NT_SUCCESS(status))
{
DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status);
}
#endif
return STATUS_PENDING; return STATUS_PENDING;
err_out_free_lc_and_req_conn_info: err_out_free_lc_and_req_conn_info:
@ -536,26 +537,6 @@ static void WSKAPI RequeueListenThread(void *p)
{ {
break; break;
} }
#if 0
/* We need to call TdiAccept here at IRQL == 0 */
if (ListenSocket->l != NULL)
{
AcceptIrp = NULL;
AcceptSocket = ListenSocket->l->AcceptSocket;
/* HACK: ReturnConnectionInfo already filled out with the remote address */
status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, ListenSocket->l->ReturnConnectionInfo, NULL, AcceptComplete, AcceptSocket);
if (!NT_SUCCESS(status))
{
DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status);
}
ListenSocket->l = NULL;
}
#endif
/* From here, ListenSocket->l may become invalid ... */
StartListening(ListenSocket); StartListening(ListenSocket);
} }
} }
@ -1267,6 +1248,7 @@ WskSocket(
s->ListenThreadHandle = NULL; s->ListenThreadHandle = NULL;
s->ListenThreadShouldRun = FALSE; s->ListenThreadShouldRun = FALSE;
s->ListenCancelled = FALSE; s->ListenCancelled = FALSE;
s->ListenSocket = NULL;
memset(&s->LocalAddress, 0, sizeof(s->LocalAddress)); memset(&s->LocalAddress, 0, sizeof(s->LocalAddress));
memset(&s->RemoteAddress, 0, sizeof(s->RemoteAddress)); memset(&s->RemoteAddress, 0, sizeof(s->RemoteAddress));
@ -1282,9 +1264,7 @@ WskSocket(
if (Flags != WSK_FLAG_LISTEN_SOCKET) if (Flags != WSK_FLAG_LISTEN_SOCKET)
{ {
DbgPrint("into TdiOpenConnectionEndpointFile ...\n");
status = TdiOpenConnectionEndpointFile(&s->TdiName, &s->ConnectionHandle, &s->ConnectionFile); status = TdiOpenConnectionEndpointFile(&s->TdiName, &s->ConnectionHandle, &s->ConnectionFile);
DbgPrint("out of TdiOpenConnectionEndpointFile ...\n");
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)
{ {
DbgPrint("Could not open TDI handle, status is %x\n", status); DbgPrint("Could not open TDI handle, status is %x\n", status);
@ -1296,7 +1276,6 @@ DbgPrint("out of TdiOpenConnectionEndpointFile ...\n");
{ {
KeInitializeEvent(&s->StartListenEvent, SynchronizationEvent, FALSE); KeInitializeEvent(&s->StartListenEvent, SynchronizationEvent, FALSE);
s->ListenThreadShouldRun = TRUE; s->ListenThreadShouldRun = TRUE;
s->l = NULL;
status = PsCreateSystemThread(&s->ListenThreadHandle, THREAD_ALL_ACCESS, NULL, NULL, NULL, RequeueListenThread, s); status = PsCreateSystemThread(&s->ListenThreadHandle, THREAD_ALL_ACCESS, NULL, NULL, NULL, RequeueListenThread, s);
if (status != STATUS_SUCCESS) if (status != STATUS_SUCCESS)