From 871eb35068cde8ff01de9a5be7e11c1262a5feb7 Mon Sep 17 00:00:00 2001 From: Johannes Thoma Date: Wed, 26 Mar 2025 15:43:16 +0000 Subject: [PATCH] AcceptSockets hold a reference to their ListenSockets. So accepted connections stay up when the listen socket is closed. --- drivers/network/netio/netio.c | 79 +++++++++++++---------------------- 1 file changed, 29 insertions(+), 50 deletions(-) diff --git a/drivers/network/netio/netio.c b/drivers/network/netio/netio.c index d08b67fc3d2..51b4ce228ff 100644 --- a/drivers/network/netio/netio.c +++ b/drivers/network/netio/netio.c @@ -51,8 +51,8 @@ ULONG DebugTraceLevel = MIN_TRACE; #define AFD_SHARE_EXCLUSIVE 0x3L /* Function trace */ -#define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__) -// #define FUNCTION_TRACE do { } while (0) +// #define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__) +#define FUNCTION_TRACE do { } while (0) #if 0 @@ -80,8 +80,6 @@ ULONG DebugTraceLevel = MIN_TRACE; #endif -struct ListenContext; - typedef struct _WSK_SOCKET_INTERNAL { WSK_SOCKET s; @@ -119,7 +117,12 @@ typedef struct _WSK_SOCKET_INTERNAL PKTHREAD ListenThread; KEVENT StartListenEvent; BOOLEAN ListenThreadShouldRun; - struct ListenContext *l; + + /* AcceptSocket's keep a reference on their listen sockets so + * that the Address file will be closed only if there are no + * open connections any more. + */ + struct _WSK_SOCKET_INTERNAL *ListenSocket; } WSK_SOCKET_INTERNAL, *PWSK_SOCKET_INTERNAL; struct NetioContext @@ -139,6 +142,9 @@ SocketGet(PWSK_SOCKET_INTERNAL s) // DbgPrint("SocketGet: refcount is %d socket is %p\n", s->RefCount, s); } +void +SocketPut(PWSK_SOCKET_INTERNAL s); + static void SocketShutdown(PWSK_SOCKET_INTERNAL s) { @@ -146,6 +152,11 @@ void SocketShutdown(PWSK_SOCKET_INTERNAL s) FUNCTION_TRACE; + if (s->ListenSocket != NULL) + { + SocketPut(s->ListenSocket); + s->ListenSocket = NULL; + } if (s->ListenThreadHandle != NULL) { s->ListenThreadShouldRun = FALSE; @@ -194,11 +205,14 @@ SocketPut(PWSK_SOCKET_INTERNAL s) { SocketShutdown(s); /* noop when called twice */ - if (s->ConnectionHandle != NULL) - { - ZwClose(s->ConnectionHandle); - s->ConnectionHandle = NULL; - } + /* Especially listen sockets must keep the LocalAddressFile + * open when being closed and there are still accepted sockets + * somewhere. Else the accepted socket I/O will fail with + * c0000184 (invalid device state). So do not close the + * address file in Shutdown only close it here when all AcceptSockets + * are gone (AcceptSockets hold a reference to the listen socket). + */ + if (s->LocalAddressFile != NULL) { ObDereferenceObject(s->LocalAddressFile); s->LocalAddressFile = NULL; @@ -329,10 +343,8 @@ static NTSTATUS CreateSocket( IoSetCompletionRoutine(NewSocketIrp, CompletionFireEvent, &CompletionEvent, TRUE, TRUE, TRUE); NewSocketIrp->Tail.Overlay.Thread = PsGetCurrentThread(); -DbgPrint("into WskSocket ...\n"); Status = WskSocket(NULL, AddressFamily, SocketType, Protocol, Flags, NULL, NULL, NULL, NULL, NULL, NewSocketIrp); -DbgPrint("out of WskSocket ...\n"); if (Status == STATUS_PENDING) { KeWaitForSingleObject(&CompletionEvent, Executive, KernelMode, FALSE, NULL); @@ -390,8 +402,6 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context) memcpy(&AcceptSocket->RemoteAddress, RemoteAddress, sizeof(AcceptSocket->RemoteAddress)); } - /* HACK!! We have a backpointer for use of TdiAccept() in RequeueListenThread() do not free l */ - ListenSocket->l = l; /* And wait for the next incoming connection. */ /* This is done in a separate thread at IRQL = 0 */ QueueListening(ListenSocket); @@ -401,7 +411,7 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context) SocketPut(AcceptSocket); SocketPut(ListenSocket); // TODO: free ReturnConnectionInfo, RequestConnectionInfo -// ExFreePoolWithTag(l, TAG_NETIO); + ExFreePoolWithTag(l, TAG_NETIO); return STATUS_SUCCESS; } @@ -413,7 +423,6 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket) NTSTATUS status; struct ListenContext *lc; PWSK_SOCKET_INTERNAL AcceptSocket; -// PIRP AcceptIrp; FUNCTION_TRACE; @@ -423,14 +432,15 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket) return STATUS_INVALID_PARAMETER; } -DbgPrint("into CreateSocket ...\n"); status = CreateSocket(ListenSocket->family, ListenSocket->type, ListenSocket->proto, WSK_FLAG_CONNECTION_SOCKET, &AcceptSocket); -DbgPrint("out of CreateSocket ...\n"); if (status != STATUS_SUCCESS) { DbgPrint("Could not create AcceptSocket, status is 0x%08x\n", status); return status; } + AcceptSocket->ListenSocket = ListenSocket; + /* Put when the AcceptSocket is closed */ + SocketGet(AcceptSocket->ListenSocket); status = STATUS_INSUFFICIENT_RESOURCES; lc = ExAllocatePoolWithTag(NonPagedPool, sizeof(*lc), TAG_NETIO); @@ -479,15 +489,6 @@ DbgPrint("out of CreateSocket ...\n"); } ListenSocket->ListenIrp = tdiIrp; -#if 0 - AcceptIrp = NULL; - status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, lc->RequestConnectionInfo, lc->ReturnConnectionInfo, AcceptComplete, AcceptSocket); - - if (!NT_SUCCESS(status)) - { - DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status); - } -#endif return STATUS_PENDING; err_out_free_lc_and_req_conn_info: @@ -536,26 +537,6 @@ static void WSKAPI RequeueListenThread(void *p) { break; } - -#if 0 - /* We need to call TdiAccept here at IRQL == 0 */ - if (ListenSocket->l != NULL) - { - AcceptIrp = NULL; - AcceptSocket = ListenSocket->l->AcceptSocket; - -/* HACK: ReturnConnectionInfo already filled out with the remote address */ - status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, ListenSocket->l->ReturnConnectionInfo, NULL, AcceptComplete, AcceptSocket); - - if (!NT_SUCCESS(status)) - { - DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status); - } - ListenSocket->l = NULL; - } -#endif - - /* From here, ListenSocket->l may become invalid ... */ StartListening(ListenSocket); } } @@ -1267,6 +1248,7 @@ WskSocket( s->ListenThreadHandle = NULL; s->ListenThreadShouldRun = FALSE; s->ListenCancelled = FALSE; + s->ListenSocket = NULL; memset(&s->LocalAddress, 0, sizeof(s->LocalAddress)); memset(&s->RemoteAddress, 0, sizeof(s->RemoteAddress)); @@ -1282,9 +1264,7 @@ WskSocket( if (Flags != WSK_FLAG_LISTEN_SOCKET) { -DbgPrint("into TdiOpenConnectionEndpointFile ...\n"); status = TdiOpenConnectionEndpointFile(&s->TdiName, &s->ConnectionHandle, &s->ConnectionFile); -DbgPrint("out of TdiOpenConnectionEndpointFile ...\n"); if (status != STATUS_SUCCESS) { DbgPrint("Could not open TDI handle, status is %x\n", status); @@ -1296,7 +1276,6 @@ DbgPrint("out of TdiOpenConnectionEndpointFile ...\n"); { KeInitializeEvent(&s->StartListenEvent, SynchronizationEvent, FALSE); s->ListenThreadShouldRun = TRUE; - s->l = NULL; status = PsCreateSystemThread(&s->ListenThreadHandle, THREAD_ALL_ACCESS, NULL, NULL, NULL, RequeueListenThread, s); if (status != STATUS_SUCCESS)