SEH protect NtReplyWaitReceivePortEx and fix one instance of message type checking to correctly account for kernel LPC messages.

svn path=/trunk/; revision=33428
This commit is contained in:
Filip Navara 2008-05-11 09:39:26 +00:00
parent 973109beb4
commit a907b85b1a

View file

@ -160,11 +160,14 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{
PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
NTSTATUS Status;
NTSTATUS Status = STATUS_SUCCESS;
PLPCP_MESSAGE Message;
PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
PLPCP_CONNECTION_MESSAGE ConnectMessage;
ULONG ConnectionInfoLength;
PORT_MESSAGE CapturedReplyMessage;
LARGE_INTEGER CapturedTimeout;
PAGED_CODE();
LPCTRACE(LPC_REPLY_DEBUG,
"Handle: %lx. Messages: %p/%p. Context: %p\n",
@ -173,8 +176,42 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ReceiveMessage,
PortContext);
/* If this is a system thread, then let it page out its stack */
if (Thread->SystemThread) WaitMode = UserMode;
if (KeGetPreviousMode() == UserMode)
{
_SEH_TRY
{
if (ReplyMessage != NULL)
{
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
ReplyMessage = &CapturedReplyMessage;
}
if (Timeout != NULL)
{
ProbeForReadLargeInteger(Timeout);
RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
Timeout = &CapturedTimeout;
}
if (PortContext != NULL)
ProbeForWritePointer(PortContext);
}
_SEH_EXCEPT(_SEH_ExSystemExceptionFilter)
{
Status = _SEH_GetExceptionCode();
}
_SEH_END;
/* Bail out if pointer was invalid */
if (!NT_SUCCESS(Status))
return Status;
}
else
{
/* If this is a system thread, then let it page out its stack */
if (Thread->SystemThread) WaitMode = UserMode;
}
/* Check if caller has a reply message */
if (ReplyMessage)
@ -388,68 +425,76 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
Thread->LpcReceivedMessageId = Message->Request.MessageId;
Thread->LpcReceivedMsgIdValid = TRUE;
/* Check if this was a connection request */
if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
_SEH_TRY
{
/* Get the connection message */
ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
LPCTRACE(LPC_REPLY_DEBUG,
"Request Messages: %p/%p\n",
Message,
ConnectMessage);
/* Get its length */
ConnectionInfoLength = Message->Request.u1.s1.DataLength -
sizeof(LPCP_CONNECTION_MESSAGE);
/* Return it as the receive message */
*ReceiveMessage = Message->Request;
/* Clear our stack variable so the message doesn't get freed */
Message = NULL;
/* Setup the receive message */
ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) +
ConnectionInfoLength);
ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
RtlCopyMemory(ReceiveMessage + 1,
ConnectMessage + 1,
ConnectionInfoLength);
/* Clear the port context if the caller requested one */
if (PortContext) *PortContext = NULL;
}
else if (Message->Request.u2.s2.Type != LPC_REPLY)
{
/* Otherwise, this is a new message or event */
LPCTRACE(LPC_REPLY_DEBUG,
"Non-Reply Messages: %p/%p\n",
&Message->Request,
(&Message->Request) + 1);
/* Copy it */
LpcpMoveMessage(ReceiveMessage,
&Message->Request,
(&Message->Request) + 1,
0,
NULL);
/* Return its context */
if (PortContext) *PortContext = Message->PortContext;
/* And check if it has data information */
if (Message->Request.u2.s2.DataInfoOffset)
/* Check if this was a connection request */
if (LpcpGetMessageType(&Message->Request) == LPC_CONNECTION_REQUEST)
{
/* It does, save it, and don't free the message below */
LpcpSaveDataInfoMessage(Port, Message, 1);
/* Get the connection message */
ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1);
LPCTRACE(LPC_REPLY_DEBUG,
"Request Messages: %p/%p\n",
Message,
ConnectMessage);
/* Get its length */
ConnectionInfoLength = Message->Request.u1.s1.DataLength -
sizeof(LPCP_CONNECTION_MESSAGE);
/* Return it as the receive message */
*ReceiveMessage = Message->Request;
/* Clear our stack variable so the message doesn't get freed */
Message = NULL;
/* Setup the receive message */
ReceiveMessage->u1.s1.TotalLength = (CSHORT)(sizeof(LPCP_MESSAGE) +
ConnectionInfoLength);
ReceiveMessage->u1.s1.DataLength = (CSHORT)ConnectionInfoLength;
RtlCopyMemory(ReceiveMessage + 1,
ConnectMessage + 1,
ConnectionInfoLength);
/* Clear the port context if the caller requested one */
if (PortContext) *PortContext = NULL;
}
else if (LpcpGetMessageType(&Message->Request) != LPC_REPLY)
{
/* Otherwise, this is a new message or event */
LPCTRACE(LPC_REPLY_DEBUG,
"Non-Reply Messages: %p/%p\n",
&Message->Request,
(&Message->Request) + 1);
/* Copy it */
LpcpMoveMessage(ReceiveMessage,
&Message->Request,
(&Message->Request) + 1,
0,
NULL);
/* Return its context */
if (PortContext) *PortContext = Message->PortContext;
/* And check if it has data information */
if (Message->Request.u2.s2.DataInfoOffset)
{
/* It does, save it, and don't free the message below */
LpcpSaveDataInfoMessage(Port, Message, 1);
Message = NULL;
}
}
else
{
/* This is a reply message, should never happen! */
ASSERT(FALSE);
}
}
else
_SEH_EXCEPT(_SEH_ExSystemExceptionFilter)
{
/* This is a reply message, should never happen! */
ASSERT(FALSE);
Status = _SEH_GetExceptionCode();
}
_SEH_END;
/* Check if we have a message pointer here */
if (Message)