[NTOS:LPC]

- Fix the usage, or add support (for NtSecureConnectPort based on patches by Alexander Andrejevic, and for NtReplyPort, NtReplyWaitReceivePortEx, NtListenPort by me) for capturing user-mode parameters and using SEH in LPC functions.
CORE-7371 #resolve
- Make NtSecureConnectPort call SeQueryInformationToken, now that the latter is implemented since r73122.
- Fix ObDereferenceObject usage for Port vs. ClientPort in NtSecureConnectPort.
- ObCloseHandle certainly needs to be called with the actual 'PreviousMode'.

svn path=/trunk/; revision=73164
This commit is contained in:
Hermès Bélusca-Maïto 2016-11-07 01:24:24 +00:00
parent a066c5bbb0
commit f197af7358
7 changed files with 388 additions and 215 deletions

View file

@ -46,6 +46,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
{
NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
PORT_VIEW CapturedServerView;
PORT_MESSAGE CapturedReplyMessage;
ULONG ConnectionInfoLength;
PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
PLPCP_CONNECTION_MESSAGE ConnectMessage;
@ -55,8 +57,6 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
PEPROCESS ClientProcess;
PETHREAD ClientThread;
LARGE_INTEGER SectionOffset;
CLIENT_ID ClientId;
ULONG MessageId;
PAGED_CODE();
LPCTRACE(LPC_COMPLETE_DEBUG,
@ -70,18 +70,15 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
/* Enter SEH for probing the parameters */
_SEH2_TRY
{
/* Probe the PortHandle */
ProbeForWriteHandle(PortHandle);
/* Probe the basic ReplyMessage structure */
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
/* Grab some values */
ClientId = ReplyMessage->ClientId;
MessageId = ReplyMessage->MessageId;
ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
/* Probe the connection info */
ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
@ -89,10 +86,11 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* The following parameters are optional */
if (ServerView != NULL)
{
ProbeForWrite(ServerView, sizeof(PORT_VIEW), sizeof(ULONG));
ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
CapturedServerView = *(volatile PORT_VIEW*)ServerView;
/* Validate the size of the server view */
if (ServerView->Length != sizeof(PORT_VIEW))
if (CapturedServerView.Length != sizeof(CapturedServerView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
@ -101,10 +99,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
if (ClientView != NULL)
{
ProbeForWrite(ClientView, sizeof(REMOTE_PORT_VIEW), sizeof(ULONG));
ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
/* Validate the size of the client view */
if (ClientView->Length != sizeof(REMOTE_PORT_VIEW))
if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
@ -120,20 +118,19 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
}
else
{
/* Grab some values */
ClientId = ReplyMessage->ClientId;
MessageId = ReplyMessage->MessageId;
ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
CapturedReplyMessage = *ReplyMessage;
ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
/* Validate the size of the server view */
if ((ServerView) && (ServerView->Length != sizeof(PORT_VIEW)))
if ((ServerView) && (ServerView->Length != sizeof(*ServerView)))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
}
CapturedServerView = *ServerView;
/* Validate the size of the client view */
if ((ClientView) && (ClientView->Length != sizeof(REMOTE_PORT_VIEW)))
if ((ClientView) && (ClientView->Length != sizeof(*ClientView)))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
@ -141,7 +138,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
}
/* Get the client process and thread */
Status = PsLookupProcessThreadByCid(&ClientId,
Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
&ClientProcess,
&ClientThread);
if (!NT_SUCCESS(Status)) return Status;
@ -151,8 +148,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Make sure that the client wants a reply, and this is the right one */
if (!(LpcpGetMessageFromThread(ClientThread)) ||
!(MessageId) ||
(ClientThread->LpcReplyMessageId != MessageId))
!(CapturedReplyMessage.MessageId) ||
(ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
{
/* Not the reply asked for, or no reply wanted, fail */
KeReleaseGuardedMutex(&LpcpLock);
@ -203,8 +200,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Setup the reply message */
Message->Request.u2.s2.Type = LPC_REPLY;
Message->Request.u2.s2.DataInfoOffset = 0;
Message->Request.ClientId = ClientId;
Message->Request.MessageId = MessageId;
Message->Request.ClientId = CapturedReplyMessage.ClientId;
Message->Request.MessageId = CapturedReplyMessage.MessageId;
Message->Request.ClientViewSize = 0;
_SEH2_TRY
@ -310,6 +307,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
if (ServerView)
{
/* FIXME: TODO */
UNREFERENCED_PARAMETER(CapturedServerView);
ASSERT(FALSE);
}
@ -347,7 +345,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
ObCloseHandle(Handle, UserMode);
ObCloseHandle(Handle, PreviousMode);
ObDereferenceObject(ServerPort);
Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(goto Cleanup);

View file

@ -90,6 +90,9 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
NTSTATUS Status = STATUS_SUCCESS;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
PETHREAD Thread = PsGetCurrentThread();
SECURITY_QUALITY_OF_SERVICE CapturedQos;
PORT_VIEW CapturedClientView;
PSID CapturedServerSid;
ULONG ConnectionInfoLength = 0;
PLPCP_PORT_OBJECT Port, ClientPort;
PLPCP_MESSAGE Message;
@ -110,25 +113,122 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
ServerView,
ServerSid);
/* Validate client view */
if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW)))
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Enter SEH for probing the parameters */
_SEH2_TRY
{
/* Probe the PortHandle */
ProbeForWriteHandle(PortHandle);
/* Validate server view */
if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW)))
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Probe and capture the QoS */
ProbeForRead(SecurityQos, sizeof(*SecurityQos), sizeof(ULONG));
CapturedQos = *(volatile SECURITY_QUALITY_OF_SERVICE*)SecurityQos;
/* NOTE: Do not care about CapturedQos.Length */
/* Check if caller sent connection information length */
if (ConnectionInformationLength)
/* The following parameters are optional */
/* Capture the client view */
if (ClientView != NULL)
{
ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
CapturedClientView = *(volatile PORT_VIEW*)ClientView;
/* Validate the size of the client view */
if (CapturedClientView.Length != sizeof(CapturedClientView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
}
}
/* Capture the server view */
if (ServerView != NULL)
{
ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
/* Validate the size of the server view */
if (((volatile REMOTE_PORT_VIEW*)ServerView)->Length != sizeof(*ServerView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
}
}
if (MaxMessageLength)
ProbeForWriteUlong(MaxMessageLength);
/* Capture connection information length */
if (ConnectionInformationLength)
{
ProbeForWriteUlong(ConnectionInformationLength);
ConnectionInfoLength = *(volatile ULONG*)ConnectionInformationLength;
}
/* Probe the ConnectionInformation */
if (ConnectionInformation)
ProbeForWrite(ConnectionInformation, ConnectionInfoLength, sizeof(ULONG));
CapturedServerSid = ServerSid;
if (ServerSid != NULL)
{
/* Capture it */
Status = SepCaptureSid(ServerSid,
PreviousMode,
PagedPool,
TRUE,
&CapturedServerSid);
if (!NT_SUCCESS(Status))
{
DPRINT1("Failed to capture ServerSid!\n");
_SEH2_YIELD(return Status);
}
}
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* There was an exception, return the exception code */
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
else
{
/* Retrieve the input length */
ConnectionInfoLength = *ConnectionInformationLength;
CapturedQos = *SecurityQos;
/* NOTE: Do not care about CapturedQos.Length */
/* The following parameters are optional */
/* Capture the client view */
if (ClientView != NULL)
{
/* Validate the size of the client view */
if (ClientView->Length != sizeof(*ClientView))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
}
CapturedClientView = *ClientView;
}
/* Capture the server view */
if (ServerView != NULL)
{
/* Validate the size of the server view */
if (ServerView->Length != sizeof(*ServerView))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
}
}
/* Capture connection information length */
if (ConnectionInformationLength)
ConnectionInfoLength = *ConnectionInformationLength;
CapturedServerSid = ServerSid;
}
/* Get the port */
@ -143,6 +243,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (!NT_SUCCESS(Status))
{
DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
return Status;
}
@ -151,10 +255,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{
/* It isn't, so fail */
ObDereferenceObject(Port);
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
return STATUS_INVALID_PORT_HANDLE;
}
/* Check if we have a SID */
/* Check if we have a (captured) SID */
if (ServerSid)
{
/* Make sure that we have a server */
@ -162,18 +270,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{
/* Get its token and query user information */
Token = PsReferencePrimaryToken(Port->ServerProcess);
//Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
// FIXME: Need SeQueryInformationToken
Status = STATUS_SUCCESS;
TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE);
TokenUserInfo->User.Sid = ServerSid;
Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
PsDereferencePrimaryToken(Token);
/* Check for success */
if (NT_SUCCESS(Status))
{
/* Compare the SIDs */
if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid))
if (!RtlEqualSid(CapturedServerSid, TokenUserInfo->User.Sid))
{
/* Fail */
Status = STATUS_SERVER_SID_MISMATCH;
@ -189,6 +293,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
Status = STATUS_SERVER_SID_MISMATCH;
}
/* Finally release the captured SID, we don't need it anymore */
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
/* Check if SID failed */
if (!NT_SUCCESS(Status))
{
@ -215,17 +323,20 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
return Status;
}
/* Setup the client port */
/*
* Setup the client port -- From now on, dereferencing the client port
* will automatically dereference the connection port too.
*/
RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
ClientPort->Flags = LPCP_CLIENT_PORT;
ClientPort->ConnectionPort = Port;
ClientPort->MaxMessageLength = Port->MaxMessageLength;
ClientPort->SecurityQos = *Qos;
ClientPort->SecurityQos = CapturedQos;
InitializeListHead(&ClientPort->LpcReplyChainHead);
InitializeListHead(&ClientPort->LpcDataInfoChainHead);
/* Check if we have dynamic security */
if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
if (CapturedQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
{
/* Remember that */
ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
@ -234,7 +345,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{
/* Create our own client security */
Status = SeCreateClientSecurity(Thread,
Qos,
&CapturedQos,
FALSE,
&ClientPort->StaticSecurity);
if (!NT_SUCCESS(Status))
@ -258,7 +369,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ClientView)
{
/* Get the section handle */
Status = ObReferenceObjectByHandle(ClientView->SectionHandle,
Status = ObReferenceObjectByHandle(CapturedClientView.SectionHandle,
SECTION_MAP_READ |
SECTION_MAP_WRITE,
MmSectionObjectType,
@ -268,12 +379,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (!NT_SUCCESS(Status))
{
/* Fail */
ObDereferenceObject(Port);
ObDereferenceObject(ClientPort);
return Status;
}
/* Set the section offset */
SectionOffset.QuadPart = ClientView->SectionOffset;
SectionOffset.QuadPart = CapturedClientView.SectionOffset;
/* Map it */
Status = MmMapViewOfSection(SectionToMap,
@ -282,25 +393,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
0,
0,
&SectionOffset,
&ClientView->ViewSize,
&CapturedClientView.ViewSize,
ViewUnmap,
0,
PAGE_READWRITE);
/* Update the offset */
ClientView->SectionOffset = SectionOffset.LowPart;
CapturedClientView.SectionOffset = SectionOffset.LowPart;
/* Check for failure */
if (!NT_SUCCESS(Status))
{
/* Fail */
ObDereferenceObject(SectionToMap);
ObDereferenceObject(Port);
ObDereferenceObject(ClientPort);
return Status;
}
/* Update the base */
ClientView->ViewBase = ClientPort->ClientSectionBase;
CapturedClientView.ViewBase = ClientPort->ClientSectionBase;
/* Reference and remember the process */
ClientPort->MappingProcess = PsGetCurrentProcess();
@ -337,12 +448,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ClientView)
{
/* Set the view size */
Message->Request.ClientViewSize = ClientView->ViewSize;
Message->Request.ClientViewSize = CapturedClientView.ViewSize;
/* Copy the client view and clear the server view */
RtlCopyMemory(&ConnectMessage->ClientView,
ClientView,
sizeof(PORT_VIEW));
&CapturedClientView,
sizeof(CapturedClientView));
RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
}
else
@ -366,12 +477,33 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
/* Check if we have connection information */
if (ConnectionInformation)
{
/* Copy it in */
RtlCopyMemory(ConnectMessage + 1,
ConnectionInformation,
ConnectionInfoLength);
_SEH2_TRY
{
/* Copy it in */
RtlCopyMemory(ConnectMessage + 1,
ConnectionInformation,
ConnectionInfoLength);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
/* Free the message we have */
LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
/* Return status */
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
/* Reset the status code */
Status = STATUS_SUCCESS;
/* Acquire the port lock */
KeAcquireGuardedMutex(&LpcpLock);
@ -433,12 +565,26 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
}
/* Check for failure */
if (!NT_SUCCESS(Status)) goto Cleanup;
/* Free the connection message */
/* Now, always free the connection message */
SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
/* Check for failure */
if (!NT_SUCCESS(Status))
{
/* Check if the semaphore got signaled in the meantime */
if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{
/* Wait on it */
KeWaitForSingleObject(&Thread->LpcReplySemaphore,
WrExecutive,
KernelMode,
FALSE,
NULL);
}
goto Failure;
}
/* Check if we got a message back */
if (Message)
{
@ -451,20 +597,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
sizeof(LPCP_CONNECTION_MESSAGE);
}
/* Check if we had connection information */
/* Check if the caller had connection information */
if (ConnectionInformation)
{
/* Check if we had a length pointer */
if (ConnectionInformationLength)
_SEH2_TRY
{
/* Return the length */
*ConnectionInformationLength = ConnectionInfoLength;
}
/* Return the connection information length if needed */
if (ConnectionInformationLength)
*ConnectionInformationLength = ConnectionInfoLength;
/* Return the connection information */
RtlCopyMemory(ConnectionInformation,
ConnectMessage + 1,
ConnectionInfoLength );
/* Return the connection information */
RtlCopyMemory(ConnectionInformation,
ConnectMessage + 1,
ConnectionInfoLength);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(goto Failure);
}
_SEH2_END;
}
/* Make sure we had a connected port */
@ -482,33 +635,45 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
&Handle);
if (NT_SUCCESS(Status))
{
/* Return the handle */
*PortHandle = Handle;
LPCTRACE(LPC_CONNECT_DEBUG,
"Handle: %p. Length: %lx\n",
Handle,
PortMessageLength);
/* Check if maximum length was requested */
if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
/* Check if we had a client view */
if (ClientView)
_SEH2_TRY
{
/* Copy it back */
RtlCopyMemory(ClientView,
&ConnectMessage->ClientView,
sizeof(PORT_VIEW));
}
/* Return the handle */
*PortHandle = Handle;
/* Check if we had a server view */
if (ServerView)
{
/* Copy it back */
RtlCopyMemory(ServerView,
&ConnectMessage->ServerView,
sizeof(REMOTE_PORT_VIEW));
/* Check if maximum length was requested */
if (MaxMessageLength)
*MaxMessageLength = PortMessageLength;
/* Check if we had a client view */
if (ClientView)
{
/* Copy it back */
RtlCopyMemory(ClientView,
&ConnectMessage->ClientView,
sizeof(*ClientView));
}
/* Check if we had a server view */
if (ServerView)
{
/* Copy it back */
RtlCopyMemory(ServerView,
&ConnectMessage->ServerView,
sizeof(*ServerView));
}
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* An exception happened, close the opened handle */
ObCloseHandle(Handle, PreviousMode);
Status = _SEH2_GetExceptionCode();
}
_SEH2_END;
}
}
else
@ -545,39 +710,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
else
{
/* No reply message, fail */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
Status = STATUS_PORT_CONNECTION_REFUSED;
goto Failure;
}
ObDereferenceObject(Port);
/* Return status */
ObDereferenceObject(Port);
return Status;
Cleanup:
/* We failed, free the message */
SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
/* Check if the semaphore got signaled */
if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{
/* Wait on it */
KeWaitForSingleObject(&Thread->LpcReplySemaphore,
WrExecutive,
KernelMode,
FALSE,
NULL);
}
Failure:
/* Check if we had a message and free it */
if (Message) LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
ObDereferenceObject(Port);
/* Return status */
ObDereferenceObject(Port);
return Status;
}

View file

@ -49,14 +49,15 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
{
NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
UNICODE_STRING CapturedObjectName, *ObjectName;
PLPCP_PORT_OBJECT Port;
HANDLE Handle;
PUNICODE_STRING ObjectName;
BOOLEAN NoName;
PAGED_CODE();
LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName);
RtlInitEmptyUnicodeString(&CapturedObjectName, NULL, 0);
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
@ -65,15 +66,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
/* Probe the PortHandle */
ProbeForWriteHandle(PortHandle);
/* Probe the ObjectAttributes */
ProbeForRead(ObjectAttributes, sizeof(OBJECT_ATTRIBUTES), sizeof(ULONG));
/* Get the object name and probe the unicode string */
ObjectName = ObjectAttributes->ObjectName;
ProbeForRead(ObjectName, sizeof(UNICODE_STRING), 1);
/* Check if we have no name */
NoName = (ObjectName->Buffer == NULL) || (ObjectName->Length == 0);
/* Probe the ObjectAttributes and its object name (not the buffer) */
ProbeForRead(ObjectAttributes, sizeof(*ObjectAttributes), sizeof(ULONG));
ObjectName = ((volatile OBJECT_ATTRIBUTES*)ObjectAttributes)->ObjectName;
if (ObjectName)
{
ProbeForRead(ObjectName, sizeof(*ObjectName), 1);
CapturedObjectName = *(volatile UNICODE_STRING*)ObjectName;
}
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
@ -84,11 +84,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
}
else
{
/* Check if we have no name */
NoName = (ObjectAttributes->ObjectName->Buffer == NULL) ||
(ObjectAttributes->ObjectName->Length == 0);
if (ObjectAttributes->ObjectName)
CapturedObjectName = *(ObjectAttributes->ObjectName);
}
/* Normalize the buffer pointer in case we don't have a name */
if (CapturedObjectName.Length == 0)
CapturedObjectName.Buffer = NULL;
/* Create the Object */
Status = ObCreateObject(PreviousMode,
LpcPortObjectType,
@ -109,7 +112,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
InitializeListHead(&Port->LpcReplyChainHead);
/* Check if we don't have a name */
if (NoName)
if (CapturedObjectName.Buffer == NULL)
{
/* Set up for an unconnected port */
Port->Flags = LPCP_UNCONNECTED_PORT;
@ -187,7 +190,8 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
ObCloseHandle(Handle, UserMode);
/* An exception happened, close the opened handle */
ObCloseHandle(Handle, PreviousMode);
Status = _SEH2_GetExceptionCode();
}
_SEH2_END;

View file

@ -36,13 +36,22 @@ NtListenPort(IN HANDLE PortHandle,
NULL,
ConnectMessage);
/* Accept only LPC_CONNECTION_REQUEST requests */
if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
_SEH2_TRY
{
/* Break out */
break;
/* Accept only LPC_CONNECTION_REQUEST requests */
if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
{
/* Break out */
_SEH2_YIELD(break);
}
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(break);
}
_SEH2_END;
}
/* Return status */

View file

@ -136,28 +136,26 @@ NtImpersonateClientOfPort(IN HANDLE PortHandle,
PAGED_CODE();
/* Check the previous mode */
PreviousMode = ExGetPreviousMode();
if (PreviousMode == KernelMode)
{
ClientId = ClientMessage->ClientId;
MessageId = ClientMessage->MessageId;
}
else
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
_SEH2_TRY
{
ProbeForRead(ClientMessage, sizeof(*ClientMessage), sizeof(PVOID));
ClientId = ClientMessage->ClientId;
MessageId = ClientMessage->MessageId;
ClientId = ((volatile PORT_MESSAGE*)ClientMessage)->ClientId;
MessageId = ((volatile PORT_MESSAGE*)ClientMessage)->MessageId;
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
DPRINT1("Got exception!\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
else
{
ClientId = ClientMessage->ClientId;
MessageId = ClientMessage->MessageId;
}
/* Reference the port handle */
Status = ObReferenceObjectByHandle(PortHandle,

View file

@ -192,7 +192,7 @@ NtReplyPort(IN HANDLE PortHandle,
{
NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
// PORT_MESSAGE CapturedReplyMessage;
PORT_MESSAGE CapturedReplyMessage;
PLPCP_PORT_OBJECT Port;
PLPCP_MESSAGE Message;
PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
@ -203,32 +203,35 @@ NtReplyPort(IN HANDLE PortHandle,
PortHandle,
ReplyMessage);
if (KeGetPreviousMode() == UserMode)
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
_SEH2_TRY
{
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
/*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
ReplyMessage = &CapturedReplyMessage;*/
ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
}
_SEH2_EXCEPT(ExSystemExceptionFilter())
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
DPRINT1("SEH crash [1]\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
else
{
CapturedReplyMessage = *ReplyMessage;
}
/* Validate its length */
if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)ReplyMessage->u1.s1.TotalLength)
if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)CapturedReplyMessage.u1.s1.TotalLength)
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Make sure it has a valid ID */
if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
/* Get the Port object */
Status = ObReferenceObjectByHandle(PortHandle,
@ -240,9 +243,9 @@ NtReplyPort(IN HANDLE PortHandle,
if (!NT_SUCCESS(Status)) return Status;
/* Validate its length in respect to the port object */
if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)ReplyMessage->u1.s1.TotalLength <=
(ULONG)ReplyMessage->u1.s1.DataLength))
if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
(ULONG)CapturedReplyMessage.u1.s1.DataLength))
{
/* Too large, fail */
ObDereferenceObject(Port);
@ -250,7 +253,7 @@ NtReplyPort(IN HANDLE PortHandle,
}
/* Get the ETHREAD corresponding to it */
Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
NULL,
&WakeupThread);
if (!NT_SUCCESS(Status))
@ -274,7 +277,7 @@ NtReplyPort(IN HANDLE PortHandle,
KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */
if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
((LpcpGetMessageFromThread(WakeupThread)) &&
(LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request)
!= LPC_REQUEST)))
@ -290,7 +293,7 @@ NtReplyPort(IN HANDLE PortHandle,
_SEH2_TRY
{
LpcpMoveMessage(&Message->Request,
ReplyMessage,
&CapturedReplyMessage,
ReplyMessage + 1,
LPC_REPLY,
NULL);
@ -324,7 +327,7 @@ NtReplyPort(IN HANDLE PortHandle,
/* Check if this is the message the thread had received */
if ((Thread->LpcReceivedMsgIdValid) &&
(Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
(Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
{
/* Clear this data */
Thread->LpcReceivedMessageId = 0;
@ -333,9 +336,9 @@ NtReplyPort(IN HANDLE PortHandle,
/* Free any data information */
LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId,
ReplyMessage->CallbackId,
ReplyMessage->ClientId);
CapturedReplyMessage.MessageId,
CapturedReplyMessage.CallbackId,
CapturedReplyMessage.ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock);
@ -362,7 +365,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{
NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
// PORT_MESSAGE CapturedReplyMessage;
PORT_MESSAGE CapturedReplyMessage;
LARGE_INTEGER CapturedTimeout;
PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
PLPCP_MESSAGE Message;
@ -378,30 +381,29 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ReceiveMessage,
PortContext);
if (KeGetPreviousMode() == UserMode)
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
_SEH2_TRY
{
if (PortContext != NULL)
ProbeForWritePointer(PortContext);
if (ReplyMessage != NULL)
{
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
/*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
ReplyMessage = &CapturedReplyMessage;*/
ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
}
if (Timeout != NULL)
{
ProbeForReadLargeInteger(Timeout);
RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER));
CapturedTimeout = *(volatile LARGE_INTEGER*)Timeout;
Timeout = &CapturedTimeout;
}
if (PortContext != NULL)
ProbeForWritePointer(PortContext);
}
_SEH2_EXCEPT(ExSystemExceptionFilter())
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
DPRINT1("SEH crash [1]\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
@ -410,21 +412,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{
/* If this is a system thread, then let it page out its stack */
if (Thread->SystemThread) WaitMode = UserMode;
if (ReplyMessage != NULL)
CapturedReplyMessage = *ReplyMessage;
}
/* Check if caller has a reply message */
if (ReplyMessage)
{
/* Validate its length */
if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)ReplyMessage->u1.s1.TotalLength)
if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)CapturedReplyMessage.u1.s1.TotalLength)
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Make sure it has a valid ID */
if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
}
/* Get the Port object */
@ -440,9 +445,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
if (ReplyMessage)
{
/* Validate its length in respect to the port object */
if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)ReplyMessage->u1.s1.TotalLength <=
(ULONG)ReplyMessage->u1.s1.DataLength))
if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
(ULONG)CapturedReplyMessage.u1.s1.DataLength))
{
/* Too large, fail */
ObDereferenceObject(Port);
@ -490,7 +495,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
if (ReplyMessage)
{
/* Get the ETHREAD corresponding to it */
Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
NULL,
&WakeupThread);
if (!NT_SUCCESS(Status))
@ -516,7 +521,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */
if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
((LpcpGetMessageFromThread(WakeupThread)) &&
(LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request)
!= LPC_REQUEST)))
@ -530,11 +535,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
}
/* Copy the message */
LpcpMoveMessage(&Message->Request,
ReplyMessage,
ReplyMessage + 1,
LPC_REPLY,
NULL);
_SEH2_TRY
{
LpcpMoveMessage(&Message->Request,
&CapturedReplyMessage,
ReplyMessage + 1,
LPC_REPLY,
NULL);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port);
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
/* Reference the thread while we use it */
ObReferenceObject(WakeupThread);
@ -555,7 +573,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Check if this is the message the thread had received */
if ((Thread->LpcReceivedMsgIdValid) &&
(Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
(Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
{
/* Clear this data */
Thread->LpcReceivedMessageId = 0;
@ -564,9 +582,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Free any data information */
LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId,
ReplyMessage->CallbackId,
ReplyMessage->ClientId);
CapturedReplyMessage.MessageId,
CapturedReplyMessage.CallbackId,
CapturedReplyMessage.ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock);
@ -688,9 +706,8 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ASSERT(FALSE);
}
}
_SEH2_EXCEPT(ExSystemExceptionFilter())
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
DPRINT1("SEH crash [2]\n");
Status = _SEH2_GetExceptionCode();
}
_SEH2_END;
@ -771,26 +788,24 @@ LpcpCopyRequestData(
PAGED_CODE();
/* Check the previous mode */
PreviousMode = ExGetPreviousMode();
if (PreviousMode == KernelMode)
{
CapturedMessage = *Message;
}
else
/* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{
_SEH2_TRY
{
ProbeForRead(Message, sizeof(*Message), sizeof(PVOID));
CapturedMessage = *Message;
CapturedMessage = *(volatile PORT_MESSAGE*)Message;
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
DPRINT1("Got exception!\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
else
{
CapturedMessage = *Message;
}
/* Make sure there is any data to copy */
if (CapturedMessage.u2.s2.DataInfoOffset == 0)

View file

@ -461,9 +461,8 @@ NtRequestPort(IN HANDLE PortHandle,
_SEH2_TRY
{
/* Probe and capture the LpcRequest */
ProbeForRead(LpcRequest, sizeof(PORT_MESSAGE), sizeof(ULONG));
ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG));
CapturedLpcRequest = *LpcRequest;
ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
@ -523,7 +522,7 @@ NtRequestPort(IN HANDLE PortHandle,
{
/* Copy it */
LpcpMoveMessage(&Message->Request,
LpcRequest,
&CapturedLpcRequest,
LpcRequest + 1,
MessageType,
&Thread->Cid);
@ -725,10 +724,9 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
{
_SEH2_TRY
{
/* Probe the full request message and copy the base structure */
/* Probe and capture the LpcRequest */
ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG));
CapturedLpcRequest = *LpcRequest;
CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
/* Probe the reply message for write */
ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));