- Fix multiple LPC race conditions.

- Improve LpcpFreeToPortZone calls for optimizing lock release.
- Use RtlCopyMemory instead of RtlMoveMemory to optimize data transfer speed.
- Always hold a reference to the connection port associated to the LPC port and properly handle this reference in all the LPC code.
- Hold a reference to the process that mapped a server/client view, and use this field when freeing memory in case we're called out-of-process.
- Fix a lot of list parsing loops and code to handle the case when the list is now empty.
- Validate more fields and data in the code.
- There are still some LPC bugs at system shutdown.

svn path=/trunk/; revision=25557
This commit is contained in:
Alex Ionescu 2007-01-21 17:21:42 +00:00
parent b4483383d1
commit 28df784f1a
10 changed files with 353 additions and 165 deletions

View file

@ -45,7 +45,7 @@
{ \ { \
/* It's still signaled, so wait on it */ \ /* It's still signaled, so wait on it */ \
KeWaitForSingleObject(s, \ KeWaitForSingleObject(s, \
Executive, \ WrExecutive, \
KernelMode, \ KernelMode, \
FALSE, \ FALSE, \
NULL); \ NULL); \
@ -73,7 +73,7 @@
{ \ { \
/* It's still signaled, so wait on it */ \ /* It's still signaled, so wait on it */ \
KeWaitForSingleObject(s, \ KeWaitForSingleObject(s, \
Executive, \ WrExecutive, \
KernelMode, \ KernelMode, \
FALSE, \ FALSE, \
NULL); \ NULL); \

View file

@ -137,7 +137,6 @@ KeSignalGateBoostPriority(IN PKGATE Gate)
KIRQL OldIrql; KIRQL OldIrql;
ASSERT_GATE(Gate); ASSERT_GATE(Gate);
ASSERT_IRQL_LESS_OR_EQUAL(DISPATCH_LEVEL); ASSERT_IRQL_LESS_OR_EQUAL(DISPATCH_LEVEL);
ASSERT(FALSE);
/* Start entry loop */ /* Start entry loop */
for (;;) for (;;)

View file

@ -19,6 +19,7 @@ NTAPI
LpcExitThread(IN PETHREAD Thread) LpcExitThread(IN PETHREAD Thread)
{ {
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
ASSERT(Thread == PsGetCurrentThread());
/* Acquire the lock */ /* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
@ -54,7 +55,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
PLPCP_CONNECTION_MESSAGE ConnectMessage; PLPCP_CONNECTION_MESSAGE ConnectMessage;
PLPCP_PORT_OBJECT ClientPort = NULL; PLPCP_PORT_OBJECT ClientPort = NULL;
PETHREAD Thread = NULL; PETHREAD Thread = NULL;
BOOLEAN LockHeld = Flags & 1; BOOLEAN LockHeld = Flags & 1, ReleaseLock = Flags & 2;
PAGED_CODE(); PAGED_CODE();
LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags); LPCTRACE(LPC_CLOSE_DEBUG, "Message: %p. Flags: %lx\n", Message, Flags);
@ -99,7 +100,7 @@ LpcpFreeToPortZone(IN PLPCP_MESSAGE Message,
ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message); ExFreeToPagedLookasideList(&LpcpMessagesLookaside, Message);
/* Reacquire the lock if needed */ /* Reacquire the lock if needed */
if ((LockHeld) && !(Flags & 2)) KeAcquireGuardedMutex(&LpcpLock); if ((LockHeld) && !(ReleaseLock)) KeAcquireGuardedMutex(&LpcpLock);
} }
VOID VOID
@ -110,14 +111,27 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
PLIST_ENTRY ListHead, NextEntry; PLIST_ENTRY ListHead, NextEntry;
PETHREAD Thread; PETHREAD Thread;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
PLPCP_PORT_OBJECT ConnectionPort = NULL;
PLPCP_CONNECTION_MESSAGE ConnectMessage; PLPCP_CONNECTION_MESSAGE ConnectMessage;
PAGED_CODE();
LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags); LPCTRACE(LPC_CLOSE_DEBUG, "Port: %p. Flags: %lx\n", Port, Port->Flags);
/* Hold the lock */ /* Hold the lock */
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
/* Disconnect the port to which this port is connected */ /* Check if we have a connected port */
if (Port->ConnectedPort) Port->ConnectedPort->ConnectedPort = NULL; if (((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_UNCONNECTED_PORT) &&
(Port->ConnectedPort))
{
/* Disconnect it */
Port->ConnectedPort->ConnectedPort = NULL;
if (Port->ConnectedPort->ConnectionPort)
{
/* Save and clear connection port */
ConnectionPort = Port->ConnectedPort->ConnectionPort;
Port->ConnectedPort->ConnectionPort = NULL;
}
}
/* Check if this is a connection port */ /* Check if this is a connection port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT)
@ -129,7 +143,7 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
/* Walk all the threads waiting and signal them */ /* Walk all the threads waiting and signal them */
ListHead = &Port->LpcReplyChainHead; ListHead = &Port->LpcReplyChainHead;
NextEntry = ListHead->Flink; NextEntry = ListHead->Flink;
while (NextEntry != ListHead) while ((NextEntry) && (NextEntry != ListHead))
{ {
/* Get the Thread */ /* Get the Thread */
Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain); Thread = CONTAINING_RECORD(NextEntry, ETHREAD, LpcReplyChain);
@ -147,58 +161,64 @@ LpcpDestroyPortQueue(IN PLPCP_PORT_OBJECT Port,
/* Check if someone is waiting */ /* Check if someone is waiting */
if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore)) if (!KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{ {
/* Get the message and check if it's a connection request */ /* Get the message */
Message = Thread->LpcReplyMessage; Message = Thread->LpcReplyMessage;
if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST) if (Message)
{ {
/* Get the connection message */ /* Check if it's a connection request */
ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); if (Message->Request.u2.s2.Type == LPC_CONNECTION_REQUEST)
/* Check if it had a section */
if (ConnectMessage->SectionToMap)
{ {
/* Dereference it */ /* Get the connection message */
ObDereferenceObject(ConnectMessage->SectionToMap); ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
/* Check if it had a section */
if (ConnectMessage->SectionToMap)
{
/* Dereference it */
ObDereferenceObject(ConnectMessage->SectionToMap);
}
} }
/* Clear the reply message */
Thread->LpcReplyMessage = NULL;
/* And remove the message from the port zone */
LpcpFreeToPortZone(Message, 1);
NextEntry = Port->LpcReplyChainHead.Flink;
} }
/* Clear the reply message */ /* Release the semaphore and reset message id count */
Thread->LpcReplyMessage = NULL; Thread->LpcReplyMessageId = 0;
KeReleaseSemaphore(&Thread->LpcReplySemaphore, 0, 1, FALSE);
/* And remove the message from the port zone */
LpcpFreeToPortZone(Message, TRUE);
} }
/* Release the semaphore and reset message id count */
Thread->LpcReplyMessageId = 0;
LpcpCompleteWait(&Thread->LpcReplySemaphore);
} }
/* Reinitialize the list head */ /* Reinitialize the list head */
InitializeListHead(&Port->LpcReplyChainHead); InitializeListHead(&Port->LpcReplyChainHead);
/* Loop queued messages */ /* Loop queued messages */
ListHead = &Port->MsgQueue.ReceiveHead; while ((Port->MsgQueue.ReceiveHead.Flink) &&
NextEntry = ListHead->Flink; !(IsListEmpty (&Port->MsgQueue.ReceiveHead)))
while ((NextEntry) && (ListHead != NextEntry))
{ {
/* Get the message */ /* Get the message */
Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry); Message = CONTAINING_RECORD(Port->MsgQueue.ReceiveHead.Flink,
NextEntry = NextEntry->Flink; LPCP_MESSAGE,
Entry);
/* Free and reinitialize it's list head */ /* Free and reinitialize it's list head */
RemoveEntryList(&Message->Entry);
InitializeListHead(&Message->Entry); InitializeListHead(&Message->Entry);
/* Remove it from the port zone */ /* Remove it from the port zone */
LpcpFreeToPortZone(Message, TRUE); LpcpFreeToPortZone(Message, 1);
} }
/* Reinitialize the message queue list head */
InitializeListHead(&Port->MsgQueue.ReceiveHead);
/* Release the lock */ /* Release the lock */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
/* Dereference the connection port */
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
/* Check if we have to free the port entirely */ /* Check if we have to free the port entirely */
if (Destroy) if (Destroy)
{ {
@ -334,13 +354,32 @@ LpcpDeletePort(IN PVOID ObjectBody)
/* Destroy the port queue */ /* Destroy the port queue */
LpcpDestroyPortQueue(Port, TRUE); LpcpDestroyPortQueue(Port, TRUE);
/* Check if we had a client view */ /* Check if we had views */
if (Port->ClientSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(), if ((Port->ClientSectionBase) || (Port->ServerSectionBase))
Port->ClientSectionBase); {
/* Check if we had a client view */
if (Port->ClientSectionBase)
{
/* Unmap it */
MmUnmapViewOfSection(Port->MappingProcess,
Port->ClientSectionBase);
}
/* Check for a server view */ /* Check for a server view */
if (Port->ServerSectionBase) MmUnmapViewOfSection(PsGetCurrentProcess(), if (Port->ServerSectionBase)
Port->ServerSectionBase); {
/* Unmap it */
MmUnmapViewOfSection(Port->MappingProcess,
Port->ServerSectionBase);
}
/* Dereference the mapping process */
ObDereferenceObject(Port->MappingProcess);
Port->MappingProcess = NULL;
}
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
/* Get the connection port */ /* Get the connection port */
ConnectionPort = Port->ConnectionPort; ConnectionPort = Port->ConnectionPort;
@ -349,9 +388,6 @@ LpcpDeletePort(IN PVOID ObjectBody)
/* Get the PID */ /* Get the PID */
Pid = PsGetCurrentProcessId(); Pid = PsGetCurrentProcessId();
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
/* Loop the data lists */ /* Loop the data lists */
ListHead = &ConnectionPort->LpcDataInfoChainHead; ListHead = &ConnectionPort->LpcDataInfoChainHead;
NextEntry = ListHead->Flink; NextEntry = ListHead->Flink;
@ -361,12 +397,29 @@ LpcpDeletePort(IN PVOID ObjectBody)
Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry); Message = CONTAINING_RECORD(NextEntry, LPCP_MESSAGE, Entry);
NextEntry = NextEntry->Flink; NextEntry = NextEntry->Flink;
/* Check if the PID matches */ /* Check if this is the connection port */
if (Message->Request.ClientId.UniqueProcess == Pid) if (Port == ConnectionPort)
{
/* Free queued messages */
RemoveEntryList(&Message->Entry);
InitializeListHead(&Message->Entry);
LpcpFreeToPortZone(Message, 1);
/* Restart at the head */
NextEntry = ListHead->Flink;
}
else if ((Message->Request.ClientId.UniqueProcess == Pid) &&
((Message->SenderPort == Port) ||
(Message->SenderPort == Port->ConnectedPort) ||
(Message->SenderPort == ConnectionPort)))
{ {
/* Remove it */ /* Remove it */
RemoveEntryList(&Message->Entry); RemoveEntryList(&Message->Entry);
LpcpFreeToPortZone(Message, TRUE); InitializeListHead(&Message->Entry);
LpcpFreeToPortZone(Message, 1);
/* Restart at the head */
NextEntry = ListHead->Flink;
} }
} }
@ -376,6 +429,11 @@ LpcpDeletePort(IN PVOID ObjectBody)
/* Dereference the object unless it's the same port */ /* Dereference the object unless it's the same port */
if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort); if (ConnectionPort != Port) ObDereferenceObject(ConnectionPort);
} }
else
{
/* Release the lock */
KeReleaseGuardedMutex(&LpcpLock);
}
/* Check if this is a connection port with a server process*/ /* Check if this is a connection port with a server process*/
if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) && if (((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CONNECTION_PORT) &&

View file

@ -145,7 +145,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
Message->Request.ClientId = ReplyMessage->ClientId; Message->Request.ClientId = ReplyMessage->ClientId;
Message->Request.MessageId = ReplyMessage->MessageId; Message->Request.MessageId = ReplyMessage->MessageId;
Message->Request.ClientViewSize = 0; Message->Request.ClientViewSize = 0;
RtlMoveMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength); RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength);
/* At this point, if the caller refused the connection, go to cleanup */ /* At this point, if the caller refused the connection, go to cleanup */
if (!AcceptConnection) goto Cleanup; if (!AcceptConnection) goto Cleanup;
@ -213,6 +213,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Set the view base */ /* Set the view base */
ConnectMessage->ClientView.ViewRemoteBase = ServerPort-> ConnectMessage->ClientView.ViewRemoteBase = ServerPort->
ClientSectionBase; ClientSectionBase;
/* Save and reference the mapping process */
ServerPort->MappingProcess = PsGetCurrentProcess();
ObReferenceObject(ServerPort->MappingProcess);
} }
else else
{ {
@ -351,10 +355,10 @@ NtCompleteConnectPort(IN HANDLE PortHandle)
/* Make sure it has a reply message */ /* Make sure it has a reply message */
if (!Thread->LpcReplyMessage) if (!Thread->LpcReplyMessage)
{ {
/* It doesn't, fail */ /* It doesn't, quit */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED; return STATUS_SUCCESS;
} }
/* Clear the client thread and wake it up */ /* Clear the client thread and wake it up */

View file

@ -21,6 +21,7 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
IN PETHREAD CurrentThread) IN PETHREAD CurrentThread)
{ {
PVOID SectionToMap; PVOID SectionToMap;
PLPCP_MESSAGE ReplyMessage;
/* Acquire the LPC lock */ /* Acquire the LPC lock */
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
@ -34,10 +35,19 @@ LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message,
} }
/* Check if there's a reply message */ /* Check if there's a reply message */
if (CurrentThread->LpcReplyMessage) ReplyMessage = CurrentThread->LpcReplyMessage;
if (ReplyMessage)
{ {
/* Get the message */ /* Get the message */
*Message = CurrentThread->LpcReplyMessage; *Message = ReplyMessage;
/* Check if it's got messages */
if (!IsListEmpty(&ReplyMessage->Entry))
{
/* Clear the list */
RemoveEntryList(&ReplyMessage->Entry);
InitializeListHead(&ReplyMessage->Entry);
}
/* Clear message data */ /* Clear message data */
CurrentThread->LpcReceivedMessageId = 0; CurrentThread->LpcReceivedMessageId = 0;
@ -124,7 +134,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
Status = ObReferenceObjectByName(PortName, Status = ObReferenceObjectByName(PortName,
0, 0,
NULL, NULL,
PORT_ALL_ACCESS, PORT_CONNECT,
LpcPortObjectType, LpcPortObjectType,
PreviousMode, PreviousMode,
NULL, NULL,
@ -286,6 +296,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
/* Update the base */ /* Update the base */
ClientView->ViewBase = Port->ClientSectionBase; ClientView->ViewBase = Port->ClientSectionBase;
/* Reference and remember the process */
ClientPort->MappingProcess = PsGetCurrentProcess();
ObReferenceObject(ClientPort->MappingProcess);
} }
else else
{ {
@ -321,7 +335,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
Message->Request.ClientViewSize = ClientView->ViewSize; Message->Request.ClientViewSize = ClientView->ViewSize;
/* Copy the client view and clear the server view */ /* Copy the client view and clear the server view */
RtlMoveMemory(&ConnectMessage->ClientView, RtlCopyMemory(&ConnectMessage->ClientView,
ClientView, ClientView,
sizeof(PORT_VIEW)); sizeof(PORT_VIEW));
RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
@ -348,7 +362,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ConnectionInformation) if (ConnectionInformation)
{ {
/* Copy it in */ /* Copy it in */
RtlMoveMemory(ConnectMessage + 1, RtlCopyMemory(ConnectMessage + 1,
ConnectionInformation, ConnectionInformation,
ConnectionInfoLength); ConnectionInfoLength);
} }
@ -360,51 +374,63 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (Port->Flags & LPCP_NAME_DELETED) if (Port->Flags & LPCP_NAME_DELETED)
{ {
/* Fail the request */ /* Fail the request */
KeReleaseGuardedMutex(&LpcpLock);
Status = STATUS_OBJECT_NAME_NOT_FOUND; Status = STATUS_OBJECT_NAME_NOT_FOUND;
goto Cleanup; }
else
{
/* Associate no thread yet */
Message->RepliedToThread = NULL;
/* Generate the Message ID and set it */
Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Thread->LpcReplyMessageId = Message->Request.MessageId;
/* Insert the message into the queue and thread chain */
InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
Thread->LpcReplyMessage = Message;
/* Now we can finally reference the client port and link it*/
ObReferenceObject(ClientPort);
ConnectMessage->ClientPort = ClientPort;
/* Enter a critical region */
KeEnterCriticalRegion();
} }
/* Associate no thread yet */ /* Add another reference to the port */
Message->RepliedToThread = NULL; ObReferenceObject(Port);
/* Generate the Message ID and set it */
Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Thread->LpcReplyMessageId = Message->Request.MessageId;
/* Insert the message into the queue and thread chain */
InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry);
InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain);
Thread->LpcReplyMessage = Message;
/* Now we can finally reference the client port and link it*/
ObReferenceObject(ClientPort);
ConnectMessage->ClientPort = ClientPort;
/* Release the lock */ /* Release the lock */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
LPCTRACE(LPC_CONNECT_DEBUG,
"Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
Message,
ConnectMessage,
Port,
ClientPort,
Status);
/* If this is a waitable port, set the event */ /* Check for success */
if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent, if (NT_SUCCESS(Status))
1, {
FALSE); LPCTRACE(LPC_CONNECT_DEBUG,
"Messages: %p/%p. Ports: %p/%p. Status: %lx\n",
Message,
ConnectMessage,
Port,
ClientPort,
Status);
/* Release the queue semaphore */ /* If this is a waitable port, set the event */
LpcpCompleteWait(Port->MsgQueue.Semaphore); if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent,
1,
FALSE);
/* Now wait for a reply */ /* Release the queue semaphore and leave the critical region */
LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); LpcpCompleteWait(Port->MsgQueue.Semaphore);
KeLeaveCriticalRegion();
/* Check if our wait ended in success */ /* Now wait for a reply */
if (Status != STATUS_SUCCESS) goto Cleanup; LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
}
/* Check for failure */
if (!NT_SUCCESS(Status)) goto Cleanup;
/* Free the connection message */ /* Free the connection message */
SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
@ -432,7 +458,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
} }
/* Return the connection information */ /* Return the connection information */
RtlMoveMemory(ConnectionInformation, RtlCopyMemory(ConnectionInformation,
ConnectMessage + 1, ConnectMessage + 1,
ConnectionInfoLength ); ConnectionInfoLength );
} }
@ -466,7 +492,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ClientView) if (ClientView)
{ {
/* Copy it back */ /* Copy it back */
RtlMoveMemory(ClientView, RtlCopyMemory(ClientView,
&ConnectMessage->ClientView, &ConnectMessage->ClientView,
sizeof(PORT_VIEW)); sizeof(PORT_VIEW));
} }
@ -475,7 +501,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ServerView) if (ServerView)
{ {
/* Copy it back */ /* Copy it back */
RtlMoveMemory(ServerView, RtlCopyMemory(ServerView,
&ConnectMessage->ServerView, &ConnectMessage->ServerView,
sizeof(REMOTE_PORT_VIEW)); sizeof(REMOTE_PORT_VIEW));
} }
@ -486,8 +512,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
/* No connection port, we failed */ /* No connection port, we failed */
if (SectionToMap) ObDereferenceObject(SectionToMap); if (SectionToMap) ObDereferenceObject(SectionToMap);
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
/* Check if it's because the name got deleted */ /* Check if it's because the name got deleted */
if (Port->Flags & LPCP_NAME_DELETED) if (!(ClientPort->ConnectionPort) ||
(Port->Flags & LPCP_NAME_DELETED))
{ {
/* Set the correct status */ /* Set the correct status */
Status = STATUS_OBJECT_NAME_NOT_FOUND; Status = STATUS_OBJECT_NAME_NOT_FOUND;
@ -498,19 +528,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
Status = STATUS_PORT_CONNECTION_REFUSED; Status = STATUS_PORT_CONNECTION_REFUSED;
} }
/* Release the lock */
KeReleaseGuardedMutex(&LpcpLock);
/* Kill the port */ /* Kill the port */
ObDereferenceObject(ClientPort); ObDereferenceObject(ClientPort);
} }
/* Free the message */ /* Free the message */
LpcpFreeToPortZone(Message, FALSE); LpcpFreeToPortZone(Message, 0);
return Status; }
else
{
/* No reply message, fail */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
Status = STATUS_PORT_CONNECTION_REFUSED;
} }
/* No reply message, fail */ /* Return status */
if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(Port);
ObDereferenceObject(ClientPort); return Status;
return STATUS_PORT_CONNECTION_REFUSED;
Cleanup: Cleanup:
/* We failed, free the message */ /* We failed, free the message */
@ -521,20 +559,21 @@ Cleanup:
{ {
/* Wait on it */ /* Wait on it */
KeWaitForSingleObject(&Thread->LpcReplySemaphore, KeWaitForSingleObject(&Thread->LpcReplySemaphore,
WrExecutive,
KernelMode, KernelMode,
Executive,
FALSE, FALSE,
NULL); NULL);
} }
/* Check if we had a message and free it */ /* Check if we had a message and free it */
if (Message) LpcpFreeToPortZone(Message, FALSE); if (Message) LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */ /* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap); if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort); ObDereferenceObject(ClientPort);
/* Return status */ /* Return status */
ObDereferenceObject(Port);
return Status; return Status;
} }

View file

@ -19,6 +19,7 @@ NTAPI
LpcpInitializePortQueue(IN PLPCP_PORT_OBJECT Port) LpcpInitializePortQueue(IN PLPCP_PORT_OBJECT Port)
{ {
PLPCP_NONPAGED_PORT_QUEUE MessageQueue; PLPCP_NONPAGED_PORT_QUEUE MessageQueue;
PAGED_CODE();
/* Allocate the queue */ /* Allocate the queue */
MessageQueue = ExAllocatePoolWithTag(NonPagedPool, MessageQueue = ExAllocatePoolWithTag(NonPagedPool,
@ -48,6 +49,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status; NTSTATUS Status;
PLPCP_PORT_OBJECT Port; PLPCP_PORT_OBJECT Port;
PAGED_CODE();
LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName); LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName);
/* Create the Object */ /* Create the Object */

View file

@ -22,30 +22,30 @@ NTAPI
NtListenPort(IN HANDLE PortHandle, NtListenPort(IN HANDLE PortHandle,
OUT PPORT_MESSAGE ConnectMessage) OUT PPORT_MESSAGE ConnectMessage)
{ {
NTSTATUS Status; NTSTATUS Status;
PAGED_CODE(); PAGED_CODE();
LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle); LPCTRACE(LPC_LISTEN_DEBUG, "Handle: %lx\n", PortHandle);
/* Wait forever for a connection request. */ /* Wait forever for a connection request. */
for (;;) for (;;)
{
/* Do the wait */
Status = NtReplyWaitReceivePort(PortHandle,
NULL,
NULL,
ConnectMessage);
/* Accept only LPC_CONNECTION_REQUEST requests. */
if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
{ {
/* Do the wait */ /* Break out */
Status = NtReplyWaitReceivePort(PortHandle, break;
NULL,
NULL,
ConnectMessage);
/* Accept only LPC_CONNECTION_REQUEST requests. */
if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
{
/* Break out */
break;
}
} }
}
/* Return status */ /* Return status */
return Status; return Status;
} }

View file

@ -18,7 +18,7 @@ POBJECT_TYPE LpcPortObjectType;
ULONG LpcpMaxMessageSize; ULONG LpcpMaxMessageSize;
PAGED_LOOKASIDE_LIST LpcpMessagesLookaside; PAGED_LOOKASIDE_LIST LpcpMessagesLookaside;
KGUARDED_MUTEX LpcpLock; KGUARDED_MUTEX LpcpLock;
ULONG LpcpTraceLevel = 0; ULONG LpcpTraceLevel = LPC_CLOSE_DEBUG;
ULONG LpcpNextMessageId = 1, LpcpNextCallbackId = 1; ULONG LpcpNextMessageId = 1, LpcpNextCallbackId = 1;
static GENERIC_MAPPING LpcpPortMapping = static GENERIC_MAPPING LpcpPortMapping =
@ -54,7 +54,6 @@ LpcpInitSystem(VOID)
ObjectTypeInitializer.CloseProcedure = LpcpClosePort; ObjectTypeInitializer.CloseProcedure = LpcpClosePort;
ObjectTypeInitializer.DeleteProcedure = LpcpDeletePort; ObjectTypeInitializer.DeleteProcedure = LpcpDeletePort;
ObjectTypeInitializer.ValidAccessMask = PORT_ALL_ACCESS; ObjectTypeInitializer.ValidAccessMask = PORT_ALL_ACCESS;
ObjectTypeInitializer.MaintainTypeList = TRUE;
ObCreateObjectType(&Name, ObCreateObjectType(&Name,
&ObjectTypeInitializer, &ObjectTypeInitializer,
NULL, NULL,

View file

@ -18,7 +18,8 @@ VOID
NTAPI NTAPI
LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port, LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
IN ULONG MessageId, IN ULONG MessageId,
IN ULONG CallbackId) IN ULONG CallbackId,
IN CLIENT_ID ClientId)
{ {
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
PLIST_ENTRY ListHead, NextEntry; PLIST_ENTRY ListHead, NextEntry;
@ -28,6 +29,7 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
{ {
/* Use it */ /* Use it */
Port = Port->ConnectionPort; Port = Port->ConnectionPort;
if (!Port) return;
} }
/* Loop the list */ /* Loop the list */
@ -40,12 +42,13 @@ LpcpFreeDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
/* Make sure it matches */ /* Make sure it matches */
if ((Message->Request.MessageId == MessageId) && if ((Message->Request.MessageId == MessageId) &&
(Message->Request.CallbackId == CallbackId)) (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread) &&
(Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess))
{ {
/* Unlink and free it */ /* Unlink and free it */
RemoveEntryList(&Message->Entry); RemoveEntryList(&Message->Entry);
InitializeListHead(&Message->Entry); InitializeListHead(&Message->Entry);
LpcpFreeToPortZone(Message, TRUE); LpcpFreeToPortZone(Message, 1);
break; break;
} }
@ -58,25 +61,31 @@ VOID
NTAPI NTAPI
LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port, LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port,
IN PLPCP_MESSAGE Message, IN PLPCP_MESSAGE Message,
IN ULONG LockFlags) IN ULONG LockHeld)
{ {
PAGED_CODE(); PAGED_CODE();
/* Acquire the lock */ /* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock); if (!LockHeld) KeAcquireGuardedMutex(&LpcpLock);
/* Check if the port we want is the connection port */ /* Check if the port we want is the connection port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT) if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT)
{ {
/* Use it */ /* Use it */
Port = Port->ConnectionPort; Port = Port->ConnectionPort;
if (!Port)
{
/* Release the lock and return */
if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
return;
}
} }
/* Link the message */ /* Link the message */
InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry); InsertTailList(&Port->LpcDataInfoChainHead, &Message->Entry);
/* Release the lock */ /* Release the lock */
KeReleaseGuardedMutex(&LpcpLock); if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock);
} }
VOID VOID
@ -119,7 +128,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
Destination->ClientViewSize = Origin->ClientViewSize; Destination->ClientViewSize = Origin->ClientViewSize;
/* Copy the Message Data */ /* Copy the Message Data */
RtlMoveMemory(Destination + 1, RtlCopyMemory(Destination + 1,
Data, Data,
((Destination->u1.Length & 0xFFFF) + 3) &~3); ((Destination->u1.Length & 0xFFFF) + 3) &~3);
} }
@ -149,7 +158,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
OUT PPORT_MESSAGE ReceiveMessage, OUT PPORT_MESSAGE ReceiveMessage,
IN PLARGE_INTEGER Timeout OPTIONAL) IN PLARGE_INTEGER Timeout OPTIONAL)
{ {
PLPCP_PORT_OBJECT Port, ReceivePort; PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
NTSTATUS Status; NTSTATUS Status;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
@ -207,8 +216,32 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Check if this is anything but a client port */ /* Check if this is anything but a client port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT) if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CLIENT_PORT)
{ {
/* Use the connection port */ /* Check if this is the connection port */
ReceivePort = Port->ConnectionPort; if (Port->ConnectionPort == Port)
{
/* Use this port */
ConnectionPort = ReceivePort = Port;
ObReferenceObject(ConnectionPort);
}
else
{
/* Acquire the lock */
KeAcquireGuardedMutex(&LpcpLock);
/* Get the port */
ConnectionPort = ReceivePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
KeReleaseGuardedMutex(&LpcpLock);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
/* Release lock and reference */
ObReferenceObject(Port);
KeReleaseGuardedMutex(&LpcpLock);
}
} }
else else
{ {
@ -227,6 +260,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{ {
/* No thread found, fail */ /* No thread found, fail */
ObDereferenceObject(Port); ObDereferenceObject(Port);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return Status; return Status;
} }
@ -235,6 +269,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
if (!Message) if (!Message)
{ {
/* Fail if we couldn't allocate a message */ /* Fail if we couldn't allocate a message */
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread); ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return STATUS_NO_MEMORY; return STATUS_NO_MEMORY;
@ -244,11 +279,16 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */ /* Make sure this is the reply the thread is waiting for */
if (WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId))// ||
#if 0
((WakeupThread->LpcReplyMessage) &&
(LpcpGetMessageType(&((PLPCP_MESSAGE)WakeupThread->
LpcReplyMessage)->Request) != LPC_REQUEST)))
#endif
{ {
/* It isn't, fail */ /* It isn't, fail */
LpcpFreeToPortZone(Message, TRUE); LpcpFreeToPortZone(Message, 3);
KeReleaseGuardedMutex(&LpcpLock); if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread); ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return STATUS_REPLY_MESSAGE_MISMATCH; return STATUS_REPLY_MESSAGE_MISMATCH;
@ -261,11 +301,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
LPC_REPLY, LPC_REPLY,
NULL); NULL);
/* Free any data information */
LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId,
ReplyMessage->CallbackId);
/* Reference the thread while we use it */ /* Reference the thread while we use it */
ObReferenceObject(WakeupThread); ObReferenceObject(WakeupThread);
Message->RepliedToThread = WakeupThread; Message->RepliedToThread = WakeupThread;
@ -292,6 +327,12 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
Thread->LpcReceivedMsgIdValid = FALSE; Thread->LpcReceivedMsgIdValid = FALSE;
} }
/* Free any data information */
LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId,
ReplyMessage->CallbackId,
ReplyMessage->ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */ /* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
LpcpCompleteWait(&WakeupThread->LpcReplySemaphore); LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
@ -319,6 +360,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Release the lock and fail */ /* Release the lock and fail */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return STATUS_UNSUCCESSFUL; return STATUS_UNSUCCESSFUL;
} }
@ -347,9 +389,6 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
Thread->LpcReceivedMessageId = Message->Request.MessageId; Thread->LpcReceivedMessageId = Message->Request.MessageId;
Thread->LpcReceivedMsgIdValid = TRUE; Thread->LpcReceivedMsgIdValid = TRUE;
/* Done touching global data, release the lock */
KeReleaseGuardedMutex(&LpcpLock);
/* Check if this was a connection request */ /* Check if this was a connection request */
if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST) if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
{ {
@ -374,7 +413,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ReceiveMessage->u1.s1.TotalLength = sizeof(LPCP_MESSAGE) + ReceiveMessage->u1.s1.TotalLength = sizeof(LPCP_MESSAGE) +
ConnectionInfoLength; ConnectionInfoLength;
ReceiveMessage->u1.s1.DataLength = ConnectionInfoLength; ReceiveMessage->u1.s1.DataLength = ConnectionInfoLength;
RtlMoveMemory(ReceiveMessage + 1, RtlCopyMemory(ReceiveMessage + 1,
ConnectMessage + 1, ConnectMessage + 1,
ConnectionInfoLength); ConnectionInfoLength);
@ -413,8 +452,17 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ASSERT(FALSE); ASSERT(FALSE);
} }
/* If we have a message pointer here, free it */ /* Check if we have a message pointer here */
if (Message) LpcpFreeToPortZone(Message, FALSE); if (Message)
{
/* Free it and release the lock */
LpcpFreeToPortZone(Message, 3);
}
else
{
/* Just release the lock */
KeReleaseGuardedMutex(&LpcpLock);
}
Cleanup: Cleanup:
/* All done, dereference the port and return the status */ /* All done, dereference the port and return the status */
@ -422,6 +470,7 @@ Cleanup:
"Port: %p. Status: %p\n", "Port: %p. Status: %p\n",
Port, Port,
Status); Status);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return Status; return Status;
} }

View file

@ -22,7 +22,7 @@ NTAPI
LpcRequestPort(IN PVOID PortObject, LpcRequestPort(IN PVOID PortObject,
IN PPORT_MESSAGE LpcMessage) IN PPORT_MESSAGE LpcMessage)
{ {
PLPCP_PORT_OBJECT Port = (PLPCP_PORT_OBJECT)PortObject, QueuePort; PLPCP_PORT_OBJECT Port = PortObject, QueuePort, ConnectionPort = NULL;
ULONG MessageType; ULONG MessageType;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
@ -42,7 +42,7 @@ LpcRequestPort(IN PVOID PortObject,
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
/* Mark this as a kernel-mode message only if we really came from there */ /* Mark this as a kernel-mode message only if we really came from it */
if ((PreviousMode == KernelMode) && if ((PreviousMode == KernelMode) &&
(LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE)) (LpcMessage->u2.s2.Type & LPC_KERNELMODE_MESSAGE))
{ {
@ -72,6 +72,7 @@ LpcRequestPort(IN PVOID PortObject,
if (!Message) return STATUS_NO_MEMORY; if (!Message) return STATUS_NO_MEMORY;
/* Clear the context */ /* Clear the context */
Message->RepliedToThread = NULL;
Message->PortContext = NULL; Message->PortContext = NULL;
/* Copy the message */ /* Copy the message */
@ -96,13 +97,28 @@ LpcRequestPort(IN PVOID PortObject,
{ {
/* Then copy the context */ /* Then copy the context */
Message->PortContext = QueuePort->PortContext; Message->PortContext = QueuePort->PortContext;
QueuePort = Port->ConnectionPort; ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
return STATUS_PORT_DISCONNECTED;
}
} }
else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT) else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT)
{ {
/* Any other kind of port, use the connection port */ /* Any other kind of port, use the connection port */
QueuePort = Port->ConnectionPort; ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
return STATUS_PORT_DISCONNECTED;
}
} }
/* If we have a connection port, reference it */
if (ConnectionPort) ObReferenceObject(ConnectionPort);
} }
} }
else else
@ -126,6 +142,7 @@ LpcRequestPort(IN PVOID PortObject,
InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry); InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
/* Release the lock and release the semaphore */ /* Release the lock and release the semaphore */
KeEnterCriticalRegion();
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
LpcpCompleteWait(QueuePort->MsgQueue.Semaphore); LpcpCompleteWait(QueuePort->MsgQueue.Semaphore);
@ -137,13 +154,15 @@ LpcRequestPort(IN PVOID PortObject,
} }
/* We're done */ /* We're done */
KeLeaveCriticalRegion();
LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message); LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
/* If we got here, then free the message and fail */ /* If we got here, then free the message and fail */
LpcpFreeToPortZone(Message, TRUE); LpcpFreeToPortZone(Message, 3);
KeReleaseGuardedMutex(&LpcpLock); if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return STATUS_PORT_DISCONNECTED; return STATUS_PORT_DISCONNECTED;
} }
@ -181,7 +200,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
IN PPORT_MESSAGE LpcRequest, IN PPORT_MESSAGE LpcRequest,
IN OUT PPORT_MESSAGE LpcReply) IN OUT PPORT_MESSAGE LpcReply)
{ {
PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort; PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status; NTSTATUS Status;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
@ -286,8 +305,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
if (!QueuePort) if (!QueuePort)
{ {
/* We have no connected port, fail */ /* We have no connected port, fail */
LpcpFreeToPortZone(Message, TRUE); LpcpFreeToPortZone(Message, 3);
KeReleaseGuardedMutex(&LpcpLock);
ObDereferenceObject(Port); ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED; return STATUS_PORT_DISCONNECTED;
} }
@ -299,15 +317,32 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT) if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
{ {
/* Copy the port context and use the connection port */ /* Copy the port context and use the connection port */
Message->PortContext = ReplyPort->PortContext; Message->PortContext = QueuePort->PortContext;
QueuePort = Port->ConnectionPort; ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
} }
else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
LPCP_COMMUNICATION_PORT) LPCP_COMMUNICATION_PORT)
{ {
/* Use the connection port for anything but communication ports */ /* Use the connection port for anything but communication ports */
QueuePort = Port->ConnectionPort; ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
} }
/* Reference the connection port if it exists */
if (ConnectionPort) ObReferenceObject(ConnectionPort);
} }
else else
{ {
@ -317,6 +352,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
/* No reply thread */ /* No reply thread */
Message->RepliedToThread = NULL; Message->RepliedToThread = NULL;
Message->SenderPort = Port;
/* Generate the Message ID and set it */ /* Generate the Message ID and set it */
Message->Request.MessageId = LpcpNextMessageId++; Message->Request.MessageId = LpcpNextMessageId++;
@ -330,8 +366,10 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
/* Insert the message in our chain */ /* Insert the message in our chain */
InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry); InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain); InsertTailList(&ReplyPort->LpcReplyChainHead, &Thread->LpcReplyChain);
Thread->LpcWaitingOnPort = Port;
/* Release the lock and get the semaphore we'll use later */ /* Release the lock and get the semaphore we'll use later */
KeEnterCriticalRegion();
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
Semaphore = QueuePort->MsgQueue.Semaphore; Semaphore = QueuePort->MsgQueue.Semaphore;
@ -345,6 +383,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
/* Now release the semaphore */ /* Now release the semaphore */
LpcpCompleteWait(Semaphore); LpcpCompleteWait(Semaphore);
KeLeaveCriticalRegion();
/* And let's wait for the reply */ /* And let's wait for the reply */
LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode); LpcpReplyWait(&Thread->LpcReplySemaphore, PreviousMode);
@ -396,7 +435,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
else else
{ {
/* Otherwise, just free it */ /* Otherwise, just free it */
LpcpFreeToPortZone(Message, FALSE); LpcpFreeToPortZone(Message, 0);
} }
} }
else else
@ -407,10 +446,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
} }
else else
{ {
/* The wait failed, free the message while holding the lock */ /* The wait failed, free the message */
KeAcquireGuardedMutex(&LpcpLock); if (Message) LpcpFreeToPortZone(Message, 0);
LpcpFreeToPortZone(Message, TRUE);
KeReleaseGuardedMutex(&LpcpLock);
} }
/* All done */ /* All done */
@ -419,6 +456,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
Port, Port,
Status); Status);
ObDereferenceObject(Port); ObDereferenceObject(Port);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return Status; return Status;
} }