From 28df784f1aa2c1fc390294bfa08b88a1f91228da Mon Sep 17 00:00:00 2001 From: Alex Ionescu Date: Sun, 21 Jan 2007 17:21:42 +0000 Subject: [PATCH] - Fix multiple LPC race conditions. - Improve LpcpFreeToPortZone calls for optimizing lock release. - Use RtlCopyMemory instead of RtlMoveMemory to optimize data transfer speed. - Always hold a reference to the connection port associated to the LPC port and properly handle this reference in all the LPC code. - Hold a reference to the process that mapped a server/client view, and use this field when freeing memory in case we're called out-of-process. - Fix a lot of list parsing loops and code to handle the case when the list is now empty. - Validate more fields and data in the code. - There are still some LPC bugs at system shutdown. svn path=/trunk/; revision=25557 --- reactos/ntoskrnl/include/internal/lpc_x.h | 4 +- reactos/ntoskrnl/ke/gate.c | 1 - reactos/ntoskrnl/lpc/close.c | 146 +++++++++++++++------- reactos/ntoskrnl/lpc/complete.c | 10 +- reactos/ntoskrnl/lpc/connect.c | 143 +++++++++++++-------- reactos/ntoskrnl/lpc/create.c | 2 + reactos/ntoskrnl/lpc/listen.c | 40 +++--- reactos/ntoskrnl/lpc/port.c | 3 +- reactos/ntoskrnl/lpc/reply.c | 97 ++++++++++---- reactos/ntoskrnl/lpc/send.c | 72 ++++++++--- 10 files changed, 353 insertions(+), 165 deletions(-) diff --git a/reactos/ntoskrnl/include/internal/lpc_x.h b/reactos/ntoskrnl/include/internal/lpc_x.h index 0ba3b0aac74..a1cc008d3e7 100644 --- a/reactos/ntoskrnl/include/internal/lpc_x.h +++ b/reactos/ntoskrnl/include/internal/lpc_x.h @@ -45,7 +45,7 @@ { \ /* It's still signaled, so wait on it */ \ KeWaitForSingleObject(s, \ - Executive, \ + WrExecutive, \ KernelMode, \ FALSE, \ NULL); \ @@ -73,7 +73,7 @@ { \ /* It's still signaled, so wait on it */ \ KeWaitForSingleObject(s, \ - Executive, \ + WrExecutive, \ KernelMode, \ FALSE, \ NULL); \ diff --git a/reactos/ntoskrnl/ke/gate.c b/reactos/ntoskrnl/ke/gate.c index f6ace425dd8..3f82d638646 100644 --- a/reactos/ntoskrnl/ke/gate.c +++ b/reactos/ntoskrnl/ke/gate.c @@ -137,7 +137,6 @@ KeSignalGateBoostPriority(IN PKGATE Gate) KIRQL OldIrql; ASSERT_GATE(Gate); ASSERT_IRQL_LESS_OR_EQUAL(DISPATCH_LEVEL); - ASSERT(FALSE); /* Start entry loop */ for (;;) diff --git a/reactos/ntoskrnl/lpc/close.c b/reactos/ntoskrnl/lpc/close.c index bf42824a203..62a0ff70948 100644 --- a/reactos/ntoskrnl/lpc/close.c +++ b/reactos/ntoskrnl/lpc/close.c @@ -19,6 +19,7 @@ NTAPI LpcExitThread(IN PETHREAD Thread) { PLPCP_MESSAGE Message; + ASSERT(Thread == PsGetCurrentThread()); /* Acquire the lock */ KeAcquireGuardedMutex(&LpcpLock); @@ -54,7 +55,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message, PLPCP_CONNECTION_MESSAGE ConnectMessage; PLPCP_PORT_OBJECT ClientPort = NULL; PETHREAD Thread = NULL; - BOOLEAN LockHeld = Flags & 1; + BOOLEAN LockHeld = Flags & 1, ReleaseLock = Flags & 2; PAGED_CODE(); LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags); @@ -99,7 +100,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message, ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message); /* Reacquire the lock if needed */ - if ((LockHeld) && !(Flags & 2)) KeAcquireGuardedMutex(&LpcpLock); + if ((LockHeld) && !(ReleaseLock)) KeAcquireGuardedMutex(&LpcpLock); } VOID @@ -110,14 +111,27 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port, PLIST_ENTRY ListHead, NextEntry; PETHREAD Thread; PLPCP_MESSAGE Message; + PLPCP_PORT_OBJECT ConnectionPort = NULL; PLPCP_CONNECTION_MESSAGE ConnectMessage; + PAGED_CODE(); LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags); /* Hold the lock */ KeAcquireGuardedMutex(&LpcpLock); - /* Disconnect the port to which this port is connected */ - if (Port->ConnectedPort) Port->ConnectedPort->ConnectedPort = NULL; + /* Check if we have a connected port */ + if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) && + (Port->ConnectedPort)) + { + /* Disconnect it */ + Port->ConnectedPort->ConnectedPort = NULL; + if (Port->ConnectedPort->ConnectionPort) + { + /* Save and clear connection port */ + ConnectionPort = Port->ConnectedPort->ConnectionPort; + Port->ConnectedPort->ConnectionPort = NULL; + } + } /* Check if this is a connection port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) @@ -129,7 +143,7 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port, /* Walk all the threads waiting and signal them */ ListHead = &Port->LpcReplyChainHead; NextEntry = ListHead->Flink; - while (NextEntry != ListHead) + while ((NextEntry) && (NextEntry != ListHead)) { /* Get the Thread */ Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain); @@ -147,58 +161,64 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port, /* Check if someone is waiting */ if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore)) { - /* Get the message and check if it's a connection request */ + /* Get the message */ Message = Thread->LpcReplyMessage; - if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST) + if (Message) { - /* Get the connection message */ - ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); - - /* Check if it had a section */ - if (ConnectMessage->SectionToMap) + /* Check if it's a connection request */ + if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST) { - /* Dereference it */ - ObDereferenceObject(ConnectMessage->SectionToMap); + /* Get the connection message */ + ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); + + /* Check if it had a section */ + if (ConnectMessage->SectionToMap) + { + /* Dereference it */ + ObDereferenceObject(ConnectMessage->SectionToMap); + } } + + /* Clear the reply message */ + Thread->LpcReplyMessage = NULL; + + /* And remove the message from the port zone */ + LpcpFreeToPortZone(Message, 1); + NextEntry = Port->LpcReplyChainHead.Flink; } - /* Clear the reply message */ - Thread->LpcReplyMessage = NULL; - - /* And remove the message from the port zone */ - LpcpFreeToPortZone(Message, TRUE); + /* Release the semaphore and reset message id count */ + Thread->LpcReplyMessageId = 0; + KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE); } - - /* Release the semaphore and reset message id count */ - Thread->LpcReplyMessageId = 0; - LpcpCompleteWait(&Thread->LpcReplySemaphore); } /* Reinitialize the list head */ InitializeListHead(&Port->LpcReplyChainHead); /* Loop queued messages */ - ListHead = &Port->MsgQueue.ReceiveHead; - NextEntry = ListHead->Flink; - while ((NextEntry) && (ListHead != NextEntry)) + while ((Port->MsgQueue.ReceiveHead.Flink) && + !(IsListEmpty (&Port->MsgQueue.ReceiveHead))) { /* Get the message */ - Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry); - NextEntry = NextEntry->Flink; + Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink, + LPCP_MESSAGE, + Entry); /* Free and reinitialize it's list head */ + RemoveEntryList(&Message->Entry); InitializeListHead(&Message->Entry); /* Remove it from the port zone */ - LpcpFreeToPortZone(Message, TRUE); + LpcpFreeToPortZone(Message, 1); } - /* Reinitialize the message queue list head */ - InitializeListHead(&Port->MsgQueue.ReceiveHead); - /* Release the lock */ KeReleaseGuardedMutex(&LpcpLock); + /* Dereference the connection port */ + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + /* Check if we have to free the port entirely */ if (Destroy) { @@ -334,13 +354,32 @@ LpcpDeletePort(IN PVOID ObjectBody) /* Destroy the port queue */ LpcpDestroyPortQueue(Port, TRUE); - /* Check if we had a client view */ - if (Port->ClientSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(), - Port->ClientSectionBase); + /* Check if we had views */ + if ((Port->ClientSectionBase) || (Port->ServerSectionBase)) + { + /* Check if we had a client view */ + if (Port->ClientSectionBase) + { + /* Unmap it */ + MmUnmapViewOfSection(Port->MappingProcess, + Port->ClientSectionBase); + } - /* Check for a server view */ - if (Port->ServerSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(), - Port->ServerSectionBase); + /* Check for a server view */ + if (Port->ServerSectionBase) + { + /* Unmap it */ + MmUnmapViewOfSection(Port->MappingProcess, + Port->ServerSectionBase); + } + + /* Dereference the mapping process */ + ObDereferenceObject(Port->MappingProcess); + Port->MappingProcess = NULL; + } + + /* Acquire the lock */ + KeAcquireGuardedMutex(&LpcpLock); /* Get the connection port */ ConnectionPort = Port->ConnectionPort; @@ -349,9 +388,6 @@ LpcpDeletePort(IN PVOID ObjectBody) /* Get the PID */ Pid = PsGetCurrentProcessId(); - /* Acquire the lock */ - KeAcquireGuardedMutex(&LpcpLock); - /* Loop the data lists */ ListHead = &ConnectionPort->LpcDataInfoChainHead; NextEntry = ListHead->Flink; @@ -361,12 +397,29 @@ LpcpDeletePort(IN PVOID ObjectBody) Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry); NextEntry = NextEntry->Flink; - /* Check if the PID matches */ - if (Message->Request.ClientId.UniqueProcess == Pid) + /* Check if this is the connection port */ + if (Port == ConnectionPort) + { + /* Free queued messages */ + RemoveEntryList(&Message->Entry); + InitializeListHead(&Message->Entry); + LpcpFreeToPortZone(Message, 1); + + /* Restart at the head */ + NextEntry = ListHead->Flink; + } + else if ((Message->Request.ClientId.UniqueProcess == Pid) && + ((Message->SenderPort == Port) || + (Message->SenderPort == Port->ConnectedPort) || + (Message->SenderPort == ConnectionPort))) { /* Remove it */ RemoveEntryList(&Message->Entry); - LpcpFreeToPortZone(Message, TRUE); + InitializeListHead(&Message->Entry); + LpcpFreeToPortZone(Message, 1); + + /* Restart at the head */ + NextEntry = ListHead->Flink; } } @@ -376,6 +429,11 @@ LpcpDeletePort(IN PVOID ObjectBody) /* Dereference the object unless it's the same port */ if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort); } + else + { + /* Release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + } /* Check if this is a connection port with a server process*/ if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) && diff --git a/reactos/ntoskrnl/lpc/complete.c b/reactos/ntoskrnl/lpc/complete.c index eeda2e4c412..33c68860c7d 100644 --- a/reactos/ntoskrnl/lpc/complete.c +++ b/reactos/ntoskrnl/lpc/complete.c @@ -145,7 +145,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, Message->Request.ClientId = ReplyMessage->ClientId; Message->Request.MessageId = ReplyMessage->MessageId; Message->Request.ClientViewSize = 0; - RtlMoveMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength); + RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength); /* At this point, if the caller refused the connection, go to cleanup */ if (!AcceptConnection) goto Cleanup; @@ -213,6 +213,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle, /* Set the view base */ ConnectMessage->ClientView.ViewRemoteBase = ServerPort-> ClientSectionBase; + + /* Save and reference the mapping process */ + ServerPort->MappingProcess = PsGetCurrentProcess(); + ObReferenceObject(ServerPort->MappingProcess); } else { @@ -351,10 +355,10 @@ NtCompleteConnectPort(IN HANDLE PortHandle) /* Make sure it has a reply message */ if (!Thread->LpcReplyMessage) { - /* It doesn't, fail */ + /* It doesn't, quit */ KeReleaseGuardedMutex(&LpcpLock); ObDereferenceObject(Port); - return STATUS_PORT_DISCONNECTED; + return STATUS_SUCCESS; } /* Clear the client thread and wake it up */ diff --git a/reactos/ntoskrnl/lpc/connect.c b/reactos/ntoskrnl/lpc/connect.c index 73a57ff3219..da0dcde4784 100644 --- a/reactos/ntoskrnl/lpc/connect.c +++ b/reactos/ntoskrnl/lpc/connect.c @@ -21,6 +21,7 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message, IN PETHREAD CurrentThread) { PVOID SectionToMap; + PLPCP_MESSAGE ReplyMessage; /* Acquire the LPC lock */ KeAcquireGuardedMutex(&LpcpLock); @@ -34,10 +35,19 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message, } /* Check if there's a reply message */ - if (CurrentThread->LpcReplyMessage) + ReplyMessage = CurrentThread->LpcReplyMessage; + if (ReplyMessage) { /* Get the message */ - *Message = CurrentThread->LpcReplyMessage; + *Message = ReplyMessage; + + /* Check if it's got messages */ + if (!IsListEmpty(&ReplyMessage->Entry)) + { + /* Clear the list */ + RemoveEntryList(&ReplyMessage->Entry); + InitializeListHead(&ReplyMessage->Entry); + } /* Clear message data */ CurrentThread->LpcReceivedMessageId = 0; @@ -124,7 +134,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, Status = ObReferenceObjectByName(PortName, 0, NULL, - PORT_ALL_ACCESS, + PORT_CONNECT, LpcPortObjectType, PreviousMode, NULL, @@ -286,6 +296,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, /* Update the base */ ClientView->ViewBase = Port->ClientSectionBase; + + /* Reference and remember the process */ + ClientPort->MappingProcess = PsGetCurrentProcess(); + ObReferenceObject(ClientPort->MappingProcess); } else { @@ -321,7 +335,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, Message->Request.ClientViewSize = ClientView->ViewSize; /* Copy the client view and clear the server view */ - RtlMoveMemory(&ConnectMessage->ClientView, + RtlCopyMemory(&ConnectMessage->ClientView, ClientView, sizeof(PORT_VIEW)); RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); @@ -348,7 +362,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (ConnectionInformation) { /* Copy it in */ - RtlMoveMemory(ConnectMessage + 1, + RtlCopyMemory(ConnectMessage + 1, ConnectionInformation, ConnectionInfoLength); } @@ -360,51 +374,63 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (Port->Flags & LPCP_NAME_DELETED) { /* Fail the request */ - KeReleaseGuardedMutex(&LpcpLock); Status = STATUS_OBJECT_NAME_NOT_FOUND; - goto Cleanup; + } + else + { + /* Associate no thread yet */ + Message->RepliedToThread = NULL; + + /* Generate the Message ID and set it */ + Message->Request.MessageId = LpcpNextMessageId++; + if (!LpcpNextMessageId) LpcpNextMessageId = 1; + Thread->LpcReplyMessageId = Message->Request.MessageId; + + /* Insert the message into the queue and thread chain */ + InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry); + InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain); + Thread->LpcReplyMessage = Message; + + /* Now we can finally reference the client port and link it*/ + ObReferenceObject(ClientPort); + ConnectMessage->ClientPort = ClientPort; + + /* Enter a critical region */ + KeEnterCriticalRegion(); } - /* Associate no thread yet */ - Message->RepliedToThread = NULL; - - /* Generate the Message ID and set it */ - Message->Request.MessageId = LpcpNextMessageId++; - if (!LpcpNextMessageId) LpcpNextMessageId = 1; - Thread->LpcReplyMessageId = Message->Request.MessageId; - - /* Insert the message into the queue and thread chain */ - InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry); - InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain); - Thread->LpcReplyMessage = Message; - - /* Now we can finally reference the client port and link it*/ - ObReferenceObject(ClientPort); - ConnectMessage->ClientPort = ClientPort; + /* Add another reference to the port */ + ObReferenceObject(Port); /* Release the lock */ KeReleaseGuardedMutex(&LpcpLock); - LPCTRACE(LPC_CONNECT_DEBUG, - "Messages: %p/%p. Ports: %p/%p. Status: %lx\n", - Message, - ConnectMessage, - Port, - ClientPort, - Status); - /* If this is a waitable port, set the event */ - if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent, - 1, - FALSE); + /* Check for success */ + if (NT_SUCCESS(Status)) + { + LPCTRACE(LPC_CONNECT_DEBUG, + "Messages: %p/%p. Ports: %p/%p. Status: %lx\n", + Message, + ConnectMessage, + Port, + ClientPort, + Status); - /* Release the queue semaphore */ - LpcpCompleteWait(Port->MsgQueue.Semaphore); + /* If this is a waitable port, set the event */ + if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent, + 1, + FALSE); - /* Now wait for a reply */ - LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); + /* Release the queue semaphore and leave the critical region */ + LpcpCompleteWait(Port->MsgQueue.Semaphore); + KeLeaveCriticalRegion(); - /* Check if our wait ended in success */ - if (Status != STATUS_SUCCESS) goto Cleanup; + /* Now wait for a reply */ + LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); + } + + /* Check for failure */ + if (!NT_SUCCESS(Status)) goto Cleanup; /* Free the connection message */ SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); @@ -432,7 +458,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, } /* Return the connection information */ - RtlMoveMemory(ConnectionInformation, + RtlCopyMemory(ConnectionInformation, ConnectMessage + 1, ConnectionInfoLength ); } @@ -466,7 +492,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (ClientView) { /* Copy it back */ - RtlMoveMemory(ClientView, + RtlCopyMemory(ClientView, &ConnectMessage->ClientView, sizeof(PORT_VIEW)); } @@ -475,7 +501,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, if (ServerView) { /* Copy it back */ - RtlMoveMemory(ServerView, + RtlCopyMemory(ServerView, &ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); } @@ -486,8 +512,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, /* No connection port, we failed */ if (SectionToMap) ObDereferenceObject(SectionToMap); + /* Acquire the lock */ + KeAcquireGuardedMutex(&LpcpLock); + /* Check if it's because the name got deleted */ - if (Port->Flags & LPCP_NAME_DELETED) + if (!(ClientPort->ConnectionPort) || + (Port->Flags & LPCP_NAME_DELETED)) { /* Set the correct status */ Status = STATUS_OBJECT_NAME_NOT_FOUND; @@ -498,19 +528,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle, Status = STATUS_PORT_CONNECTION_REFUSED; } + /* Release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + /* Kill the port */ ObDereferenceObject(ClientPort); } /* Free the message */ - LpcpFreeToPortZone(Message, FALSE); - return Status; + LpcpFreeToPortZone(Message, 0); + } + else + { + /* No reply message, fail */ + if (SectionToMap) ObDereferenceObject(SectionToMap); + ObDereferenceObject(ClientPort); + Status = STATUS_PORT_CONNECTION_REFUSED; } - /* No reply message, fail */ - if (SectionToMap) ObDereferenceObject(SectionToMap); - ObDereferenceObject(ClientPort); - return STATUS_PORT_CONNECTION_REFUSED; + /* Return status */ + ObDereferenceObject(Port); + return Status; Cleanup: /* We failed, free the message */ @@ -521,20 +559,21 @@ Cleanup: { /* Wait on it */ KeWaitForSingleObject(&Thread->LpcReplySemaphore, + WrExecutive, KernelMode, - Executive, FALSE, NULL); } /* Check if we had a message and free it */ - if (Message) LpcpFreeToPortZone(Message, FALSE); + if (Message) LpcpFreeToPortZone(Message, 0); /* Dereference other objects */ if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); /* Return status */ + ObDereferenceObject(Port); return Status; } diff --git a/reactos/ntoskrnl/lpc/create.c b/reactos/ntoskrnl/lpc/create.c index a0478d5ed5e..f98df57bf22 100644 --- a/reactos/ntoskrnl/lpc/create.c +++ b/reactos/ntoskrnl/lpc/create.c @@ -19,6 +19,7 @@ NTAPI LpcpInitializePortQueue(IN PLPCP_PORT_OBJECT Port) { PLPCP_NONPAGED_PORT_QUEUE MessageQueue; + PAGED_CODE(); /* Allocate the queue */ MessageQueue = ExAllocatePoolWithTag(NonPagedPool, @@ -48,6 +49,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle, KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); NTSTATUS Status; PLPCP_PORT_OBJECT Port; + PAGED_CODE(); LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName); /* Create the Object */ diff --git a/reactos/ntoskrnl/lpc/listen.c b/reactos/ntoskrnl/lpc/listen.c index 8bc84c324be..1d4e94c416d 100644 --- a/reactos/ntoskrnl/lpc/listen.c +++ b/reactos/ntoskrnl/lpc/listen.c @@ -22,30 +22,30 @@ NTAPI NtListenPort(IN HANDLE PortHandle, OUT PPORT_MESSAGE ConnectMessage) { - NTSTATUS Status; - PAGED_CODE(); - LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle); + NTSTATUS Status; + PAGED_CODE(); + LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle); - /* Wait forever for a connection request. */ - for (;;) + /* Wait forever for a connection request. */ + for (;;) + { + /* Do the wait */ + Status = NtReplyWaitReceivePort(PortHandle, + NULL, + NULL, + ConnectMessage); + + /* Accept only LPC_CONNECTION_REQUEST requests. */ + if ((Status != STATUS_SUCCESS) || + (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST)) { - /* Do the wait */ - Status = NtReplyWaitReceivePort(PortHandle, - NULL, - NULL, - ConnectMessage); - - /* Accept only LPC_CONNECTION_REQUEST requests. */ - if ((Status != STATUS_SUCCESS) || - (LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST)) - { - /* Break out */ - break; - } + /* Break out */ + break; } + } - /* Return status */ - return Status; + /* Return status */ + return Status; } diff --git a/reactos/ntoskrnl/lpc/port.c b/reactos/ntoskrnl/lpc/port.c index b085e9183d3..92bcb4d59aa 100644 --- a/reactos/ntoskrnl/lpc/port.c +++ b/reactos/ntoskrnl/lpc/port.c @@ -18,7 +18,7 @@ POBJECT_TYPE LpcPortObjectType; ULONG LpcpMaxMessageSize; PAGED_LOOKASIDE_LIST LpcpMessagesLookaside; KGUARDED_MUTEX LpcpLock; -ULONG LpcpTraceLevel = 0; +ULONG LpcpTraceLevel = LPC_CLOSE_DEBUG; ULONG LpcpNextMessageId = 1, LpcpNextCallbackId = 1; static GENERIC_MAPPING LpcpPortMapping = @@ -54,7 +54,6 @@ LpcpInitSystem(VOID) ObjectTypeInitializer.CloseProcedure = LpcpClosePort; ObjectTypeInitializer.DeleteProcedure = LpcpDeletePort; ObjectTypeInitializer.ValidAccessMask = PORT_ALL_ACCESS; - ObjectTypeInitializer.MaintainTypeList = TRUE; ObCreateObjectType(&Name, &ObjectTypeInitializer, NULL, diff --git a/reactos/ntoskrnl/lpc/reply.c b/reactos/ntoskrnl/lpc/reply.c index 751564b8666..ac421fbf915 100644 --- a/reactos/ntoskrnl/lpc/reply.c +++ b/reactos/ntoskrnl/lpc/reply.c @@ -18,7 +18,8 @@ VOID NTAPI LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port, IN ULONG MessageId, - IN ULONG CallbackId) + IN ULONG CallbackId, + IN CLIENT_ID ClientId) { PLPCP_MESSAGE Message; PLIST_ENTRY ListHead, NextEntry; @@ -28,6 +29,7 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port, { /* Use it */ Port = Port->ConnectionPort; + if (!Port) return; } /* Loop the list */ @@ -40,12 +42,13 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port, /* Make sure it matches */ if ((Message->Request.MessageId == MessageId) && - (Message->Request.CallbackId == CallbackId)) + (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread) && + (Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess)) { /* Unlink and free it */ RemoveEntryList(&Message->Entry); InitializeListHead(&Message->Entry); - LpcpFreeToPortZone(Message, TRUE); + LpcpFreeToPortZone(Message, 1); break; } @@ -58,25 +61,31 @@ VOID NTAPI LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port, IN PLPCP_MESSAGE Message, - IN ULONG LockFlags) + IN ULONG LockHeld) { PAGED_CODE(); /* Acquire the lock */ - KeAcquireGuardedMutex(&LpcpLock); + if (!LockHeld) KeAcquireGuardedMutex(&LpcpLock); /* Check if the port we want is the connection port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT) { /* Use it */ Port = Port->ConnectionPort; + if (!Port) + { + /* Release the lock and return */ + if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock); + return; + } } /* Link the message */ InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry); /* Release the lock */ - KeReleaseGuardedMutex(&LpcpLock); + if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock); } VOID @@ -119,7 +128,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination, Destination->ClientViewSize = Origin->ClientViewSize; /* Copy the Message Data */ - RtlMoveMemory(Destination + 1, + RtlCopyMemory(Destination + 1, Data, ((Destination->u1.Length & 0xFFFF) + 3) &~3); } @@ -149,7 +158,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, OUT PPORT_MESSAGE ReceiveMessage, IN PLARGE_INTEGER Timeout OPTIONAL) { - PLPCP_PORT_OBJECT Port, ReceivePort; + PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode; NTSTATUS Status; PLPCP_MESSAGE Message; @@ -207,8 +216,32 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, /* Check if this is anything but a client port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT) { - /* Use the connection port */ - ReceivePort = Port->ConnectionPort; + /* Check if this is the connection port */ + if (Port->ConnectionPort == Port) + { + /* Use this port */ + ConnectionPort = ReceivePort = Port; + ObReferenceObject(ConnectionPort); + } + else + { + /* Acquire the lock */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Get the port */ + ConnectionPort = ReceivePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + KeReleaseGuardedMutex(&LpcpLock); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } + + /* Release lock and reference */ + ObReferenceObject(Port); + KeReleaseGuardedMutex(&LpcpLock); + } } else { @@ -227,6 +260,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, { /* No thread found, fail */ ObDereferenceObject(Port); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); return Status; } @@ -235,6 +269,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, if (!Message) { /* Fail if we couldn't allocate a message */ + if (ConnectionPort) ObDereferenceObject(ConnectionPort); ObDereferenceObject(WakeupThread); ObDereferenceObject(Port); return STATUS_NO_MEMORY; @@ -244,11 +279,16 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, KeAcquireGuardedMutex(&LpcpLock); /* Make sure this is the reply the thread is waiting for */ - if (WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) + if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId))// || +#if 0 + ((WakeupThread->LpcReplyMessage) && + (LpcpGetMessageType(&((PLPCP_MESSAGE)WakeupThread-> + LpcReplyMessage)->Request) != LPC_REQUEST))) +#endif { /* It isn't, fail */ - LpcpFreeToPortZone(Message, TRUE); - KeReleaseGuardedMutex(&LpcpLock); + LpcpFreeToPortZone(Message, 3); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); ObDereferenceObject(WakeupThread); ObDereferenceObject(Port); return STATUS_REPLY_MESSAGE_MISMATCH; @@ -261,11 +301,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, LPC_REPLY, NULL); - /* Free any data information */ - LpcpFreeDataInfoMessage(Port, - ReplyMessage->MessageId, - ReplyMessage->CallbackId); - /* Reference the thread while we use it */ ObReferenceObject(WakeupThread); Message->RepliedToThread = WakeupThread; @@ -292,6 +327,12 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, Thread->LpcReceivedMsgIdValid = FALSE; } + /* Free any data information */ + LpcpFreeDataInfoMessage(Port, + ReplyMessage->MessageId, + ReplyMessage->CallbackId, + ReplyMessage->ClientId); + /* Release the lock and release the LPC semaphore to wake up waiters */ KeReleaseGuardedMutex(&LpcpLock); LpcpCompleteWait(&WakeupThread->LpcReplySemaphore); @@ -319,6 +360,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, /* Release the lock and fail */ KeReleaseGuardedMutex(&LpcpLock); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); ObDereferenceObject(Port); return STATUS_UNSUCCESSFUL; } @@ -347,9 +389,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, Thread->LpcReceivedMessageId = Message->Request.MessageId; Thread->LpcReceivedMsgIdValid = TRUE; - /* Done touching global data, release the lock */ - KeReleaseGuardedMutex(&LpcpLock); - /* Check if this was a connection request */ if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST) { @@ -374,7 +413,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, ReceiveMessage->u1.s1.TotalLength = sizeof(LPCP_MESSAGE) + ConnectionInfoLength; ReceiveMessage->u1.s1.DataLength = ConnectionInfoLength; - RtlMoveMemory(ReceiveMessage + 1, + RtlCopyMemory(ReceiveMessage + 1, ConnectMessage + 1, ConnectionInfoLength); @@ -413,8 +452,17 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, ASSERT(FALSE); } - /* If we have a message pointer here, free it */ - if (Message) LpcpFreeToPortZone(Message, FALSE); + /* Check if we have a message pointer here */ + if (Message) + { + /* Free it and release the lock */ + LpcpFreeToPortZone(Message, 3); + } + else + { + /* Just release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + } Cleanup: /* All done, dereference the port and return the status */ @@ -422,6 +470,7 @@ Cleanup: "Port: %p. Status: %p\n", Port, Status); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); ObDereferenceObject(Port); return Status; } diff --git a/reactos/ntoskrnl/lpc/send.c b/reactos/ntoskrnl/lpc/send.c index 3658ced2ca4..525267c1e84 100644 --- a/reactos/ntoskrnl/lpc/send.c +++ b/reactos/ntoskrnl/lpc/send.c @@ -22,7 +22,7 @@ NTAPI LpcRequestPort(IN PVOID PortObject, IN PPORT_MESSAGE LpcMessage) { - PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject, QueuePort; + PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL; ULONG MessageType; PLPCP_MESSAGE Message; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); @@ -42,7 +42,7 @@ LpcRequestPort(IN PVOID PortObject, return STATUS_INVALID_PARAMETER; } - /* Mark this as a kernel-mode message only if we really came from there */ + /* Mark this as a kernel-mode message only if we really came from it */ if ((PreviousMode == KernelMode) && (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE)) { @@ -72,6 +72,7 @@ LpcRequestPort(IN PVOID PortObject, if (!Message) return STATUS_NO_MEMORY; /* Clear the context */ + Message->RepliedToThread = NULL; Message->PortContext = NULL; /* Copy the message */ @@ -96,13 +97,28 @@ LpcRequestPort(IN PVOID PortObject, { /* Then copy the context */ Message->PortContext = QueuePort->PortContext; - QueuePort = Port->ConnectionPort; + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + return STATUS_PORT_DISCONNECTED; + } } else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT) { /* Any other kind of port, use the connection port */ - QueuePort = Port->ConnectionPort; + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + return STATUS_PORT_DISCONNECTED; + } } + + /* If we have a connection port, reference it */ + if (ConnectionPort) ObReferenceObject(ConnectionPort); } } else @@ -126,6 +142,7 @@ LpcRequestPort(IN PVOID PortObject, InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry); /* Release the lock and release the semaphore */ + KeEnterCriticalRegion(); KeReleaseGuardedMutex(&LpcpLock); LpcpCompleteWait(QueuePort->MsgQueue.Semaphore); @@ -137,13 +154,15 @@ LpcRequestPort(IN PVOID PortObject, } /* We're done */ + KeLeaveCriticalRegion(); LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); return STATUS_SUCCESS; } /* If we got here, then free the message and fail */ - LpcpFreeToPortZone(Message, TRUE); - KeReleaseGuardedMutex(&LpcpLock); + LpcpFreeToPortZone(Message, 3); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); return STATUS_PORT_DISCONNECTED; } @@ -181,7 +200,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, IN PPORT_MESSAGE LpcRequest, IN OUT PPORT_MESSAGE LpcReply) { - PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort; + PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); NTSTATUS Status; PLPCP_MESSAGE Message; @@ -286,8 +305,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, if (!QueuePort) { /* We have no connected port, fail */ - LpcpFreeToPortZone(Message, TRUE); - KeReleaseGuardedMutex(&LpcpLock); + LpcpFreeToPortZone(Message, 3); ObDereferenceObject(Port); return STATUS_PORT_DISCONNECTED; } @@ -299,15 +317,32 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT) { /* Copy the port context and use the connection port */ - Message->PortContext = ReplyPort->PortContext; - QueuePort = Port->ConnectionPort; + Message->PortContext = QueuePort->PortContext; + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } } else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT) { /* Use the connection port for anything but communication ports */ - QueuePort = Port->ConnectionPort; + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } } + + /* Reference the connection port if it exists */ + if (ConnectionPort) ObReferenceObject(ConnectionPort); } else { @@ -317,6 +352,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* No reply thread */ Message->RepliedToThread = NULL; + Message->SenderPort = Port; /* Generate the Message ID and set it */ Message->Request.MessageId = LpcpNextMessageId++; @@ -330,8 +366,10 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* Insert the message in our chain */ InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry); InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain); + Thread->LpcWaitingOnPort = Port; /* Release the lock and get the semaphore we'll use later */ + KeEnterCriticalRegion(); KeReleaseGuardedMutex(&LpcpLock); Semaphore = QueuePort->MsgQueue.Semaphore; @@ -345,6 +383,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* Now release the semaphore */ LpcpCompleteWait(Semaphore); + KeLeaveCriticalRegion(); /* And let's wait for the reply */ LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode); @@ -396,7 +435,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, else { /* Otherwise, just free it */ - LpcpFreeToPortZone(Message, FALSE); + LpcpFreeToPortZone(Message, 0); } } else @@ -407,10 +446,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, } else { - /* The wait failed, free the message while holding the lock */ - KeAcquireGuardedMutex(&LpcpLock); - LpcpFreeToPortZone(Message, TRUE); - KeReleaseGuardedMutex(&LpcpLock); + /* The wait failed, free the message */ + if (Message) LpcpFreeToPortZone(Message, 0); } /* All done */ @@ -419,6 +456,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, Port, Status); ObDereferenceObject(Port); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); return Status; }