From a907b85b1a6317019cfd05e78a027bd3ad7c7781 Mon Sep 17 00:00:00 2001 From: Filip Navara Date: Sun, 11 May 2008 09:39:26 +0000 Subject: [PATCH] SEH protect NtReplyWaitReceivePortEx and fix one instance of message type checking to correctly account for kernel LPC messages. svn path=/trunk/; revision=33428 --- reactos/ntoskrnl/lpc/reply.c | 161 ++++++++++++++++++++++------------- 1 file changed, 103 insertions(+), 58 deletions(-) diff --git a/reactos/ntoskrnl/lpc/reply.c b/reactos/ntoskrnl/lpc/reply.c index 55fa1c0d7eb..a0630ac1d8b 100644 --- a/reactos/ntoskrnl/lpc/reply.c +++ b/reactos/ntoskrnl/lpc/reply.c @@ -160,11 +160,14 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, { PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode; - NTSTATUS Status; + NTSTATUS Status = STATUS_SUCCESS; PLPCP_MESSAGE Message; PETHREAD Thread = PsGetCurrentThread(), WakeupThread; PLPCP_CONNECTION_MESSAGE ConnectMessage; ULONG ConnectionInfoLength; + PORT_MESSAGE CapturedReplyMessage; + LARGE_INTEGER CapturedTimeout; + PAGED_CODE(); LPCTRACE(LPC_REPLY_DEBUG, "Handle: %lx. Messages: %p/%p. Context: %p\n", @@ -173,8 +176,42 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, ReceiveMessage, PortContext); - /* If this is a system thread, then let it page out its stack */ - if (Thread->SystemThread) WaitMode = UserMode; + if (KeGetPreviousMode() == UserMode) + { + _SEH_TRY + { + if (ReplyMessage != NULL) + { + ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); + RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); + ReplyMessage = &CapturedReplyMessage; + } + + if (Timeout != NULL) + { + ProbeForReadLargeInteger(Timeout); + RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER)); + Timeout = &CapturedTimeout; + } + + if (PortContext != NULL) + ProbeForWritePointer(PortContext); + } + _SEH_EXCEPT(_SEH_ExSystemExceptionFilter) + { + Status = _SEH_GetExceptionCode(); + } + _SEH_END; + + /* Bail out if pointer was invalid */ + if (!NT_SUCCESS(Status)) + return Status; + } + else + { + /* If this is a system thread, then let it page out its stack */ + if (Thread->SystemThread) WaitMode = UserMode; + } /* Check if caller has a reply message */ if (ReplyMessage) @@ -388,68 +425,76 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle, Thread->LpcReceivedMessageId = Message->Request.MessageId; Thread->LpcReceivedMsgIdValid = TRUE; - /* Check if this was a connection request */ - if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST) + _SEH_TRY { - /* Get the connection message */ - ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); - LPCTRACE(LPC_REPLY_DEBUG, - "Request Messages: %p/%p\n", - Message, - ConnectMessage); - - /* Get its length */ - ConnectionInfoLength = Message->Request.u1.s1.DataLength - - sizeof(LPCP_CONNECTION_MESSAGE); - - /* Return it as the receive message */ - *ReceiveMessage = Message->Request; - - /* Clear our stack variable so the message doesn't get freed */ - Message = NULL; - - /* Setup the receive message */ - ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) + - ConnectionInfoLength); - ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength; - RtlCopyMemory(ReceiveMessage + 1, - ConnectMessage + 1, - ConnectionInfoLength); - - /* Clear the port context if the caller requested one */ - if (PortContext) *PortContext = NULL; - } - else if (Message->Request.u2.s2.Type != LPC_REPLY) - { - /* Otherwise, this is a new message or event */ - LPCTRACE(LPC_REPLY_DEBUG, - "Non-Reply Messages: %p/%p\n", - &Message->Request, - (&Message->Request) + 1); - - /* Copy it */ - LpcpMoveMessage(ReceiveMessage, - &Message->Request, - (&Message->Request) + 1, - 0, - NULL); - - /* Return its context */ - if (PortContext) *PortContext = Message->PortContext; - - /* And check if it has data information */ - if (Message->Request.u2.s2.DataInfoOffset) + /* Check if this was a connection request */ + if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST) { - /* It does, save it, and don't free the message below */ - LpcpSaveDataInfoMessage(Port, Message, 1); + /* Get the connection message */ + ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); + LPCTRACE(LPC_REPLY_DEBUG, + "Request Messages: %p/%p\n", + Message, + ConnectMessage); + + /* Get its length */ + ConnectionInfoLength = Message->Request.u1.s1.DataLength - + sizeof(LPCP_CONNECTION_MESSAGE); + + /* Return it as the receive message */ + *ReceiveMessage = Message->Request; + + /* Clear our stack variable so the message doesn't get freed */ Message = NULL; + + /* Setup the receive message */ + ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) + + ConnectionInfoLength); + ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength; + RtlCopyMemory(ReceiveMessage + 1, + ConnectMessage + 1, + ConnectionInfoLength); + + /* Clear the port context if the caller requested one */ + if (PortContext) *PortContext = NULL; + } + else if (LpcpGetMessageType(&Message->Request) != LPC_REPLY) + { + /* Otherwise, this is a new message or event */ + LPCTRACE(LPC_REPLY_DEBUG, + "Non-Reply Messages: %p/%p\n", + &Message->Request, + (&Message->Request) + 1); + + /* Copy it */ + LpcpMoveMessage(ReceiveMessage, + &Message->Request, + (&Message->Request) + 1, + 0, + NULL); + + /* Return its context */ + if (PortContext) *PortContext = Message->PortContext; + + /* And check if it has data information */ + if (Message->Request.u2.s2.DataInfoOffset) + { + /* It does, save it, and don't free the message below */ + LpcpSaveDataInfoMessage(Port, Message, 1); + Message = NULL; + } + } + else + { + /* This is a reply message, should never happen! */ + ASSERT(FALSE); } } - else + _SEH_EXCEPT(_SEH_ExSystemExceptionFilter) { - /* This is a reply message, should never happen! */ - ASSERT(FALSE); + Status = _SEH_GetExceptionCode(); } + _SEH_END; /* Check if we have a message pointer here */ if (Message)