AcceptSockets hold a reference to their ListenSockets.

So accepted connections stay up when the listen socket
is closed.
This commit is contained in:
Johannes Thoma 2025-03-26 15:43:16 +00:00
parent 21aa24c6b8
commit 871eb35068

View file

@ -51,8 +51,8 @@ ULONG DebugTraceLevel = MIN_TRACE;
#define AFD_SHARE_EXCLUSIVE 0x3L
/* Function trace */
#define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__)
// #define FUNCTION_TRACE do { } while (0)
// #define FUNCTION_TRACE DbgPrint("Function %s ...\n", __func__)
#define FUNCTION_TRACE do { } while (0)
#if 0
@ -80,8 +80,6 @@ ULONG DebugTraceLevel = MIN_TRACE;
#endif
struct ListenContext;
typedef struct _WSK_SOCKET_INTERNAL
{
WSK_SOCKET s;
@ -119,7 +117,12 @@ typedef struct _WSK_SOCKET_INTERNAL
PKTHREAD ListenThread;
KEVENT StartListenEvent;
BOOLEAN ListenThreadShouldRun;
struct ListenContext *l;
/* AcceptSocket's keep a reference on their listen sockets so
* that the Address file will be closed only if there are no
* open connections any more.
*/
struct _WSK_SOCKET_INTERNAL *ListenSocket;
} WSK_SOCKET_INTERNAL, *PWSK_SOCKET_INTERNAL;
struct NetioContext
@ -139,6 +142,9 @@ SocketGet(PWSK_SOCKET_INTERNAL s)
// DbgPrint("SocketGet: refcount is %d socket is %p\n", s->RefCount, s);
}
void
SocketPut(PWSK_SOCKET_INTERNAL s);
static
void SocketShutdown(PWSK_SOCKET_INTERNAL s)
{
@ -146,6 +152,11 @@ void SocketShutdown(PWSK_SOCKET_INTERNAL s)
FUNCTION_TRACE;
if (s->ListenSocket != NULL)
{
SocketPut(s->ListenSocket);
s->ListenSocket = NULL;
}
if (s->ListenThreadHandle != NULL)
{
s->ListenThreadShouldRun = FALSE;
@ -194,11 +205,14 @@ SocketPut(PWSK_SOCKET_INTERNAL s)
{
SocketShutdown(s); /* noop when called twice */
if (s->ConnectionHandle != NULL)
{
ZwClose(s->ConnectionHandle);
s->ConnectionHandle = NULL;
}
/* Especially listen sockets must keep the LocalAddressFile
* open when being closed and there are still accepted sockets
* somewhere. Else the accepted socket I/O will fail with
* c0000184 (invalid device state). So do not close the
* address file in Shutdown only close it here when all AcceptSockets
* are gone (AcceptSockets hold a reference to the listen socket).
*/
if (s->LocalAddressFile != NULL) {
ObDereferenceObject(s->LocalAddressFile);
s->LocalAddressFile = NULL;
@ -329,10 +343,8 @@ static NTSTATUS CreateSocket(
IoSetCompletionRoutine(NewSocketIrp, CompletionFireEvent, &CompletionEvent, TRUE, TRUE, TRUE);
NewSocketIrp->Tail.Overlay.Thread = PsGetCurrentThread();
DbgPrint("into WskSocket ...\n");
Status = WskSocket(NULL, AddressFamily, SocketType, Protocol, Flags,
NULL, NULL, NULL, NULL, NULL, NewSocketIrp);
DbgPrint("out of WskSocket ...\n");
if (Status == STATUS_PENDING) {
KeWaitForSingleObject(&CompletionEvent, Executive, KernelMode, FALSE, NULL);
@ -390,8 +402,6 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context)
memcpy(&AcceptSocket->RemoteAddress, RemoteAddress, sizeof(AcceptSocket->RemoteAddress));
}
/* HACK!! We have a backpointer for use of TdiAccept() in RequeueListenThread() do not free l */
ListenSocket->l = l;
/* And wait for the next incoming connection. */
/* This is done in a separate thread at IRQL = 0 */
QueueListening(ListenSocket);
@ -401,7 +411,7 @@ ListenComplete(PDEVICE_OBJECT DeviceObject, PIRP Irp, PVOID Context)
SocketPut(AcceptSocket);
SocketPut(ListenSocket);
// TODO: free ReturnConnectionInfo, RequestConnectionInfo
// ExFreePoolWithTag(l, TAG_NETIO);
ExFreePoolWithTag(l, TAG_NETIO);
return STATUS_SUCCESS;
}
@ -413,7 +423,6 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket)
NTSTATUS status;
struct ListenContext *lc;
PWSK_SOCKET_INTERNAL AcceptSocket;
// PIRP AcceptIrp;
FUNCTION_TRACE;
@ -423,14 +432,15 @@ StartListening(PWSK_SOCKET_INTERNAL ListenSocket)
return STATUS_INVALID_PARAMETER;
}
DbgPrint("into CreateSocket ...\n");
status = CreateSocket(ListenSocket->family, ListenSocket->type, ListenSocket->proto, WSK_FLAG_CONNECTION_SOCKET, &AcceptSocket);
DbgPrint("out of CreateSocket ...\n");
if (status != STATUS_SUCCESS)
{
DbgPrint("Could not create AcceptSocket, status is 0x%08x\n", status);
return status;
}
AcceptSocket->ListenSocket = ListenSocket;
/* Put when the AcceptSocket is closed */
SocketGet(AcceptSocket->ListenSocket);
status = STATUS_INSUFFICIENT_RESOURCES;
lc = ExAllocatePoolWithTag(NonPagedPool, sizeof(*lc), TAG_NETIO);
@ -479,15 +489,6 @@ DbgPrint("out of CreateSocket ...\n");
}
ListenSocket->ListenIrp = tdiIrp;
#if 0
AcceptIrp = NULL;
status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, lc->RequestConnectionInfo, lc->ReturnConnectionInfo, AcceptComplete, AcceptSocket);
if (!NT_SUCCESS(status))
{
DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status);
}
#endif
return STATUS_PENDING;
err_out_free_lc_and_req_conn_info:
@ -536,26 +537,6 @@ static void WSKAPI RequeueListenThread(void *p)
{
break;
}
#if 0
/* We need to call TdiAccept here at IRQL == 0 */
if (ListenSocket->l != NULL)
{
AcceptIrp = NULL;
AcceptSocket = ListenSocket->l->AcceptSocket;
/* HACK: ReturnConnectionInfo already filled out with the remote address */
status = TdiAccept(&AcceptIrp, AcceptSocket->ConnectionFile, ListenSocket->l->ReturnConnectionInfo, NULL, AcceptComplete, AcceptSocket);
if (!NT_SUCCESS(status))
{
DbgPrint("TdiAccept returned non-successful status 0x%08x\n", status);
}
ListenSocket->l = NULL;
}
#endif
/* From here, ListenSocket->l may become invalid ... */
StartListening(ListenSocket);
}
}
@ -1267,6 +1248,7 @@ WskSocket(
s->ListenThreadHandle = NULL;
s->ListenThreadShouldRun = FALSE;
s->ListenCancelled = FALSE;
s->ListenSocket = NULL;
memset(&s->LocalAddress, 0, sizeof(s->LocalAddress));
memset(&s->RemoteAddress, 0, sizeof(s->RemoteAddress));
@ -1282,9 +1264,7 @@ WskSocket(
if (Flags != WSK_FLAG_LISTEN_SOCKET)
{
DbgPrint("into TdiOpenConnectionEndpointFile ...\n");
status = TdiOpenConnectionEndpointFile(&s->TdiName, &s->ConnectionHandle, &s->ConnectionFile);
DbgPrint("out of TdiOpenConnectionEndpointFile ...\n");
if (status != STATUS_SUCCESS)
{
DbgPrint("Could not open TDI handle, status is %x\n", status);
@ -1296,7 +1276,6 @@ DbgPrint("out of TdiOpenConnectionEndpointFile ...\n");
{
KeInitializeEvent(&s->StartListenEvent, SynchronizationEvent, FALSE);
s->ListenThreadShouldRun = TRUE;
s->l = NULL;
status = PsCreateSystemThread(&s->ListenThreadHandle, THREAD_ALL_ACCESS, NULL, NULL, NULL, RequeueListenThread, s);
if (status != STATUS_SUCCESS)