[NTOS:LPC]

- Fix the usage, or add support (for NtSecureConnectPort based on patches by Alexander Andrejevic, and for NtReplyPort, NtReplyWaitReceivePortEx, NtListenPort by me) for capturing user-mode parameters and using SEH in LPC functions.
CORE-7371 #resolve
- Make NtSecureConnectPort call SeQueryInformationToken, now that the latter is implemented since r73122.
- Fix ObDereferenceObject usage for Port vs. ClientPort in NtSecureConnectPort.
- ObCloseHandle certainly needs to be called with the actual 'PreviousMode'.

svn path=/trunk/; revision=73164
This commit is contained in:
Hermès Bélusca-Maïto 2016-11-07 01:24:24 +00:00
parent a066c5bbb0
commit f197af7358
7 changed files with 388 additions and 215 deletions

View file

@ -46,6 +46,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
{ {
NTSTATUS Status; NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
PORT_VIEW CapturedServerView;
PORT_MESSAGE CapturedReplyMessage;
ULONG ConnectionInfoLength; ULONG ConnectionInfoLength;
PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort; PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort;
PLPCP_CONNECTION_MESSAGE ConnectMessage; PLPCP_CONNECTION_MESSAGE ConnectMessage;
@ -55,8 +57,6 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
PEPROCESS ClientProcess; PEPROCESS ClientProcess;
PETHREAD ClientThread; PETHREAD ClientThread;
LARGE_INTEGER SectionOffset; LARGE_INTEGER SectionOffset;
CLIENT_ID ClientId;
ULONG MessageId;
PAGED_CODE(); PAGED_CODE();
LPCTRACE(LPC_COMPLETE_DEBUG, LPCTRACE(LPC_COMPLETE_DEBUG,
@ -70,18 +70,15 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Check if the call comes from user mode */ /* Check if the call comes from user mode */
if (PreviousMode != KernelMode) if (PreviousMode != KernelMode)
{ {
/* Enter SEH for probing the parameters */
_SEH2_TRY _SEH2_TRY
{ {
/* Probe the PortHandle */
ProbeForWriteHandle(PortHandle); ProbeForWriteHandle(PortHandle);
/* Probe the basic ReplyMessage structure */ /* Probe the basic ReplyMessage structure */
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
/* Grab some values */ ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
ClientId = ReplyMessage->ClientId;
MessageId = ReplyMessage->MessageId;
ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
/* Probe the connection info */ /* Probe the connection info */
ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1); ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1);
@ -89,10 +86,11 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* The following parameters are optional */ /* The following parameters are optional */
if (ServerView != NULL) if (ServerView != NULL)
{ {
ProbeForWrite(ServerView, sizeof(PORT_VIEW), sizeof(ULONG)); ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
CapturedServerView = *(volatile PORT_VIEW*)ServerView;
/* Validate the size of the server view */ /* Validate the size of the server view */
if (ServerView->Length != sizeof(PORT_VIEW)) if (CapturedServerView.Length != sizeof(CapturedServerView))
{ {
/* Invalid size */ /* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER); _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
@ -101,10 +99,10 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
if (ClientView != NULL) if (ClientView != NULL)
{ {
ProbeForWrite(ClientView, sizeof(REMOTE_PORT_VIEW), sizeof(ULONG)); ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
/* Validate the size of the client view */ /* Validate the size of the client view */
if (ClientView->Length != sizeof(REMOTE_PORT_VIEW)) if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView))
{ {
/* Invalid size */ /* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER); _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
@ -120,20 +118,19 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
} }
else else
{ {
/* Grab some values */ CapturedReplyMessage = *ReplyMessage;
ClientId = ReplyMessage->ClientId; ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength;
MessageId = ReplyMessage->MessageId;
ConnectionInfoLength = ReplyMessage->u1.s1.DataLength;
/* Validate the size of the server view */ /* Validate the size of the server view */
if ((ServerView) && (ServerView->Length != sizeof(PORT_VIEW))) if ((ServerView) && (ServerView->Length != sizeof(*ServerView)))
{ {
/* Invalid size */ /* Invalid size */
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
CapturedServerView = *ServerView;
/* Validate the size of the client view */ /* Validate the size of the client view */
if ((ClientView) && (ClientView->Length != sizeof(REMOTE_PORT_VIEW))) if ((ClientView) && (ClientView->Length != sizeof(*ClientView)))
{ {
/* Invalid size */ /* Invalid size */
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
@ -141,7 +138,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
} }
/* Get the client process and thread */ /* Get the client process and thread */
Status = PsLookupProcessThreadByCid(&ClientId, Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
&ClientProcess, &ClientProcess,
&ClientThread); &ClientThread);
if (!NT_SUCCESS(Status)) return Status; if (!NT_SUCCESS(Status)) return Status;
@ -151,8 +148,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Make sure that the client wants a reply, and this is the right one */ /* Make sure that the client wants a reply, and this is the right one */
if (!(LpcpGetMessageFromThread(ClientThread)) || if (!(LpcpGetMessageFromThread(ClientThread)) ||
!(MessageId) || !(CapturedReplyMessage.MessageId) ||
(ClientThread->LpcReplyMessageId != MessageId)) (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId))
{ {
/* Not the reply asked for, or no reply wanted, fail */ /* Not the reply asked for, or no reply wanted, fail */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
@ -203,8 +200,8 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
/* Setup the reply message */ /* Setup the reply message */
Message->Request.u2.s2.Type = LPC_REPLY; Message->Request.u2.s2.Type = LPC_REPLY;
Message->Request.u2.s2.DataInfoOffset = 0; Message->Request.u2.s2.DataInfoOffset = 0;
Message->Request.ClientId = ClientId; Message->Request.ClientId = CapturedReplyMessage.ClientId;
Message->Request.MessageId = MessageId; Message->Request.MessageId = CapturedReplyMessage.MessageId;
Message->Request.ClientViewSize = 0; Message->Request.ClientViewSize = 0;
_SEH2_TRY _SEH2_TRY
@ -310,6 +307,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
if (ServerView) if (ServerView)
{ {
/* FIXME: TODO */ /* FIXME: TODO */
UNREFERENCED_PARAMETER(CapturedServerView);
ASSERT(FALSE); ASSERT(FALSE);
} }
@ -347,7 +345,7 @@ NtAcceptConnectPort(OUT PHANDLE PortHandle,
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
/* Cleanup and return the exception code */ /* Cleanup and return the exception code */
ObCloseHandle(Handle, UserMode); ObCloseHandle(Handle, PreviousMode);
ObDereferenceObject(ServerPort); ObDereferenceObject(ServerPort);
Status = _SEH2_GetExceptionCode(); Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(goto Cleanup); _SEH2_YIELD(goto Cleanup);

View file

@ -90,6 +90,9 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
NTSTATUS Status = STATUS_SUCCESS; NTSTATUS Status = STATUS_SUCCESS;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
PETHREAD Thread = PsGetCurrentThread(); PETHREAD Thread = PsGetCurrentThread();
SECURITY_QUALITY_OF_SERVICE CapturedQos;
PORT_VIEW CapturedClientView;
PSID CapturedServerSid;
ULONG ConnectionInfoLength = 0; ULONG ConnectionInfoLength = 0;
PLPCP_PORT_OBJECT Port, ClientPort; PLPCP_PORT_OBJECT Port, ClientPort;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
@ -110,25 +113,122 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
ServerView, ServerView,
ServerSid); ServerSid);
/* Validate client view */ /* Check if the call comes from user mode */
if ((ClientView) && (ClientView->Length != sizeof(PORT_VIEW))) if (PreviousMode != KernelMode)
{ {
/* Fail */ /* Enter SEH for probing the parameters */
return STATUS_INVALID_PARAMETER; _SEH2_TRY
} {
/* Probe the PortHandle */
ProbeForWriteHandle(PortHandle);
/* Validate server view */ /* Probe and capture the QoS */
if ((ServerView) && (ServerView->Length != sizeof(REMOTE_PORT_VIEW))) ProbeForRead(SecurityQos, sizeof(*SecurityQos), sizeof(ULONG));
{ CapturedQos = *(volatile SECURITY_QUALITY_OF_SERVICE*)SecurityQos;
/* Fail */ /* NOTE: Do not care about CapturedQos.Length */
return STATUS_INVALID_PARAMETER;
}
/* Check if caller sent connection information length */ /* The following parameters are optional */
if (ConnectionInformationLength)
/* Capture the client view */
if (ClientView != NULL)
{
ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG));
CapturedClientView = *(volatile PORT_VIEW*)ClientView;
/* Validate the size of the client view */
if (CapturedClientView.Length != sizeof(CapturedClientView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
}
}
/* Capture the server view */
if (ServerView != NULL)
{
ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG));
/* Validate the size of the server view */
if (((volatile REMOTE_PORT_VIEW*)ServerView)->Length != sizeof(*ServerView))
{
/* Invalid size */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER);
}
}
if (MaxMessageLength)
ProbeForWriteUlong(MaxMessageLength);
/* Capture connection information length */
if (ConnectionInformationLength)
{
ProbeForWriteUlong(ConnectionInformationLength);
ConnectionInfoLength = *(volatile ULONG*)ConnectionInformationLength;
}
/* Probe the ConnectionInformation */
if (ConnectionInformation)
ProbeForWrite(ConnectionInformation, ConnectionInfoLength, sizeof(ULONG));
CapturedServerSid = ServerSid;
if (ServerSid != NULL)
{
/* Capture it */
Status = SepCaptureSid(ServerSid,
PreviousMode,
PagedPool,
TRUE,
&CapturedServerSid);
if (!NT_SUCCESS(Status))
{
DPRINT1("Failed to capture ServerSid!\n");
_SEH2_YIELD(return Status);
}
}
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* There was an exception, return the exception code */
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
else
{ {
/* Retrieve the input length */ CapturedQos = *SecurityQos;
ConnectionInfoLength = *ConnectionInformationLength; /* NOTE: Do not care about CapturedQos.Length */
/* The following parameters are optional */
/* Capture the client view */
if (ClientView != NULL)
{
/* Validate the size of the client view */
if (ClientView->Length != sizeof(*ClientView))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
}
CapturedClientView = *ClientView;
}
/* Capture the server view */
if (ServerView != NULL)
{
/* Validate the size of the server view */
if (ServerView->Length != sizeof(*ServerView))
{
/* Invalid size */
return STATUS_INVALID_PARAMETER;
}
}
/* Capture connection information length */
if (ConnectionInformationLength)
ConnectionInfoLength = *ConnectionInformationLength;
CapturedServerSid = ServerSid;
} }
/* Get the port */ /* Get the port */
@ -143,6 +243,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status); DPRINT1("Failed to reference port '%wZ': 0x%lx\n", PortName, Status);
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
return Status; return Status;
} }
@ -151,10 +255,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{ {
/* It isn't, so fail */ /* It isn't, so fail */
ObDereferenceObject(Port); ObDereferenceObject(Port);
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
return STATUS_INVALID_PORT_HANDLE; return STATUS_INVALID_PORT_HANDLE;
} }
/* Check if we have a SID */ /* Check if we have a (captured) SID */
if (ServerSid) if (ServerSid)
{ {
/* Make sure that we have a server */ /* Make sure that we have a server */
@ -162,18 +270,14 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{ {
/* Get its token and query user information */ /* Get its token and query user information */
Token = PsReferencePrimaryToken(Port->ServerProcess); Token = PsReferencePrimaryToken(Port->ServerProcess);
//Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo); Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo);
// FIXME: Need SeQueryInformationToken
Status = STATUS_SUCCESS;
TokenUserInfo = ExAllocatePoolWithTag(PagedPool, sizeof(TOKEN_USER), TAG_SE);
TokenUserInfo->User.Sid = ServerSid;
PsDereferencePrimaryToken(Token); PsDereferencePrimaryToken(Token);
/* Check for success */ /* Check for success */
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
{ {
/* Compare the SIDs */ /* Compare the SIDs */
if (!RtlEqualSid(ServerSid, TokenUserInfo->User.Sid)) if (!RtlEqualSid(CapturedServerSid, TokenUserInfo->User.Sid))
{ {
/* Fail */ /* Fail */
Status = STATUS_SERVER_SID_MISMATCH; Status = STATUS_SERVER_SID_MISMATCH;
@ -189,6 +293,10 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
Status = STATUS_SERVER_SID_MISMATCH; Status = STATUS_SERVER_SID_MISMATCH;
} }
/* Finally release the captured SID, we don't need it anymore */
if (CapturedServerSid != ServerSid)
SepReleaseSid(CapturedServerSid, PreviousMode, TRUE);
/* Check if SID failed */ /* Check if SID failed */
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
@ -215,17 +323,20 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
return Status; return Status;
} }
/* Setup the client port */ /*
* Setup the client port -- From now on, dereferencing the client port
* will automatically dereference the connection port too.
*/
RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT)); RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT));
ClientPort->Flags = LPCP_CLIENT_PORT; ClientPort->Flags = LPCP_CLIENT_PORT;
ClientPort->ConnectionPort = Port; ClientPort->ConnectionPort = Port;
ClientPort->MaxMessageLength = Port->MaxMessageLength; ClientPort->MaxMessageLength = Port->MaxMessageLength;
ClientPort->SecurityQos = *Qos; ClientPort->SecurityQos = CapturedQos;
InitializeListHead(&ClientPort->LpcReplyChainHead); InitializeListHead(&ClientPort->LpcReplyChainHead);
InitializeListHead(&ClientPort->LpcDataInfoChainHead); InitializeListHead(&ClientPort->LpcDataInfoChainHead);
/* Check if we have dynamic security */ /* Check if we have dynamic security */
if (Qos->ContextTrackingMode == SECURITY_DYNAMIC_TRACKING) if (CapturedQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING)
{ {
/* Remember that */ /* Remember that */
ClientPort->Flags |= LPCP_SECURITY_DYNAMIC; ClientPort->Flags |= LPCP_SECURITY_DYNAMIC;
@ -234,7 +345,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
{ {
/* Create our own client security */ /* Create our own client security */
Status = SeCreateClientSecurity(Thread, Status = SeCreateClientSecurity(Thread,
Qos, &CapturedQos,
FALSE, FALSE,
&ClientPort->StaticSecurity); &ClientPort->StaticSecurity);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
@ -258,7 +369,7 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ClientView) if (ClientView)
{ {
/* Get the section handle */ /* Get the section handle */
Status = ObReferenceObjectByHandle(ClientView->SectionHandle, Status = ObReferenceObjectByHandle(CapturedClientView.SectionHandle,
SECTION_MAP_READ | SECTION_MAP_READ |
SECTION_MAP_WRITE, SECTION_MAP_WRITE,
MmSectionObjectType, MmSectionObjectType,
@ -268,12 +379,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
/* Fail */ /* Fail */
ObDereferenceObject(Port); ObDereferenceObject(ClientPort);
return Status; return Status;
} }
/* Set the section offset */ /* Set the section offset */
SectionOffset.QuadPart = ClientView->SectionOffset; SectionOffset.QuadPart = CapturedClientView.SectionOffset;
/* Map it */ /* Map it */
Status = MmMapViewOfSection(SectionToMap, Status = MmMapViewOfSection(SectionToMap,
@ -282,25 +393,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
0, 0,
0, 0,
&SectionOffset, &SectionOffset,
&ClientView->ViewSize, &CapturedClientView.ViewSize,
ViewUnmap, ViewUnmap,
0, 0,
PAGE_READWRITE); PAGE_READWRITE);
/* Update the offset */ /* Update the offset */
ClientView->SectionOffset = SectionOffset.LowPart; CapturedClientView.SectionOffset = SectionOffset.LowPart;
/* Check for failure */ /* Check for failure */
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
/* Fail */ /* Fail */
ObDereferenceObject(SectionToMap); ObDereferenceObject(SectionToMap);
ObDereferenceObject(Port); ObDereferenceObject(ClientPort);
return Status; return Status;
} }
/* Update the base */ /* Update the base */
ClientView->ViewBase = ClientPort->ClientSectionBase; CapturedClientView.ViewBase = ClientPort->ClientSectionBase;
/* Reference and remember the process */ /* Reference and remember the process */
ClientPort->MappingProcess = PsGetCurrentProcess(); ClientPort->MappingProcess = PsGetCurrentProcess();
@ -337,12 +448,12 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
if (ClientView) if (ClientView)
{ {
/* Set the view size */ /* Set the view size */
Message->Request.ClientViewSize = ClientView->ViewSize; Message->Request.ClientViewSize = CapturedClientView.ViewSize;
/* Copy the client view and clear the server view */ /* Copy the client view and clear the server view */
RtlCopyMemory(&ConnectMessage->ClientView, RtlCopyMemory(&ConnectMessage->ClientView,
ClientView, &CapturedClientView,
sizeof(PORT_VIEW)); sizeof(CapturedClientView));
RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW));
} }
else else
@ -366,12 +477,33 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
/* Check if we have connection information */ /* Check if we have connection information */
if (ConnectionInformation) if (ConnectionInformation)
{ {
/* Copy it in */ _SEH2_TRY
RtlCopyMemory(ConnectMessage + 1, {
ConnectionInformation, /* Copy it in */
ConnectionInfoLength); RtlCopyMemory(ConnectMessage + 1,
ConnectionInformation,
ConnectionInfoLength);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
/* Free the message we have */
LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
/* Return status */
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
} }
/* Reset the status code */
Status = STATUS_SUCCESS;
/* Acquire the port lock */ /* Acquire the port lock */
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
@ -433,12 +565,26 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode);
} }
/* Check for failure */ /* Now, always free the connection message */
if (!NT_SUCCESS(Status)) goto Cleanup;
/* Free the connection message */
SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
/* Check for failure */
if (!NT_SUCCESS(Status))
{
/* Check if the semaphore got signaled in the meantime */
if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{
/* Wait on it */
KeWaitForSingleObject(&Thread->LpcReplySemaphore,
WrExecutive,
KernelMode,
FALSE,
NULL);
}
goto Failure;
}
/* Check if we got a message back */ /* Check if we got a message back */
if (Message) if (Message)
{ {
@ -451,20 +597,27 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
sizeof(LPCP_CONNECTION_MESSAGE); sizeof(LPCP_CONNECTION_MESSAGE);
} }
/* Check if we had connection information */ /* Check if the caller had connection information */
if (ConnectionInformation) if (ConnectionInformation)
{ {
/* Check if we had a length pointer */ _SEH2_TRY
if (ConnectionInformationLength)
{ {
/* Return the length */ /* Return the connection information length if needed */
*ConnectionInformationLength = ConnectionInfoLength; if (ConnectionInformationLength)
} *ConnectionInformationLength = ConnectionInfoLength;
/* Return the connection information */ /* Return the connection information */
RtlCopyMemory(ConnectionInformation, RtlCopyMemory(ConnectionInformation,
ConnectMessage + 1, ConnectMessage + 1,
ConnectionInfoLength ); ConnectionInfoLength);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(goto Failure);
}
_SEH2_END;
} }
/* Make sure we had a connected port */ /* Make sure we had a connected port */
@ -482,33 +635,45 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
&Handle); &Handle);
if (NT_SUCCESS(Status)) if (NT_SUCCESS(Status))
{ {
/* Return the handle */
*PortHandle = Handle;
LPCTRACE(LPC_CONNECT_DEBUG, LPCTRACE(LPC_CONNECT_DEBUG,
"Handle: %p. Length: %lx\n", "Handle: %p. Length: %lx\n",
Handle, Handle,
PortMessageLength); PortMessageLength);
/* Check if maximum length was requested */ _SEH2_TRY
if (MaxMessageLength) *MaxMessageLength = PortMessageLength;
/* Check if we had a client view */
if (ClientView)
{ {
/* Copy it back */ /* Return the handle */
RtlCopyMemory(ClientView, *PortHandle = Handle;
&ConnectMessage->ClientView,
sizeof(PORT_VIEW));
}
/* Check if we had a server view */ /* Check if maximum length was requested */
if (ServerView) if (MaxMessageLength)
{ *MaxMessageLength = PortMessageLength;
/* Copy it back */
RtlCopyMemory(ServerView, /* Check if we had a client view */
&ConnectMessage->ServerView, if (ClientView)
sizeof(REMOTE_PORT_VIEW)); {
/* Copy it back */
RtlCopyMemory(ClientView,
&ConnectMessage->ClientView,
sizeof(*ClientView));
}
/* Check if we had a server view */
if (ServerView)
{
/* Copy it back */
RtlCopyMemory(ServerView,
&ConnectMessage->ServerView,
sizeof(*ServerView));
}
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* An exception happened, close the opened handle */
ObCloseHandle(Handle, PreviousMode);
Status = _SEH2_GetExceptionCode();
}
_SEH2_END;
} }
} }
else else
@ -545,39 +710,25 @@ NtSecureConnectPort(OUT PHANDLE PortHandle,
else else
{ {
/* No reply message, fail */ /* No reply message, fail */
if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort);
Status = STATUS_PORT_CONNECTION_REFUSED; Status = STATUS_PORT_CONNECTION_REFUSED;
goto Failure;
} }
ObDereferenceObject(Port);
/* Return status */ /* Return status */
ObDereferenceObject(Port);
return Status; return Status;
Cleanup: Failure:
/* We failed, free the message */
SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread);
/* Check if the semaphore got signaled */
if (KeReadStateSemaphore(&Thread->LpcReplySemaphore))
{
/* Wait on it */
KeWaitForSingleObject(&Thread->LpcReplySemaphore,
WrExecutive,
KernelMode,
FALSE,
NULL);
}
/* Check if we had a message and free it */ /* Check if we had a message and free it */
if (Message) LpcpFreeToPortZone(Message, 0); if (Message) LpcpFreeToPortZone(Message, 0);
/* Dereference other objects */ /* Dereference other objects */
if (SectionToMap) ObDereferenceObject(SectionToMap); if (SectionToMap) ObDereferenceObject(SectionToMap);
ObDereferenceObject(ClientPort); ObDereferenceObject(ClientPort);
ObDereferenceObject(Port);
/* Return status */ /* Return status */
ObDereferenceObject(Port);
return Status; return Status;
} }

View file

@ -49,14 +49,15 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
{ {
NTSTATUS Status; NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
UNICODE_STRING CapturedObjectName, *ObjectName;
PLPCP_PORT_OBJECT Port; PLPCP_PORT_OBJECT Port;
HANDLE Handle; HANDLE Handle;
PUNICODE_STRING ObjectName;
BOOLEAN NoName;
PAGED_CODE(); PAGED_CODE();
LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName); LPCTRACE(LPC_CREATE_DEBUG, "Name: %wZ\n", ObjectAttributes->ObjectName);
RtlInitEmptyUnicodeString(&CapturedObjectName, NULL, 0);
/* Check if the call comes from user mode */ /* Check if the call comes from user mode */
if (PreviousMode != KernelMode) if (PreviousMode != KernelMode)
{ {
@ -65,15 +66,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
/* Probe the PortHandle */ /* Probe the PortHandle */
ProbeForWriteHandle(PortHandle); ProbeForWriteHandle(PortHandle);
/* Probe the ObjectAttributes */ /* Probe the ObjectAttributes and its object name (not the buffer) */
ProbeForRead(ObjectAttributes, sizeof(OBJECT_ATTRIBUTES), sizeof(ULONG)); ProbeForRead(ObjectAttributes, sizeof(*ObjectAttributes), sizeof(ULONG));
ObjectName = ((volatile OBJECT_ATTRIBUTES*)ObjectAttributes)->ObjectName;
/* Get the object name and probe the unicode string */ if (ObjectName)
ObjectName = ObjectAttributes->ObjectName; {
ProbeForRead(ObjectName, sizeof(UNICODE_STRING), 1); ProbeForRead(ObjectName, sizeof(*ObjectName), 1);
CapturedObjectName = *(volatile UNICODE_STRING*)ObjectName;
/* Check if we have no name */ }
NoName = (ObjectName->Buffer == NULL) || (ObjectName->Length == 0);
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
@ -84,11 +84,14 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
} }
else else
{ {
/* Check if we have no name */ if (ObjectAttributes->ObjectName)
NoName = (ObjectAttributes->ObjectName->Buffer == NULL) || CapturedObjectName = *(ObjectAttributes->ObjectName);
(ObjectAttributes->ObjectName->Length == 0);
} }
/* Normalize the buffer pointer in case we don't have a name */
if (CapturedObjectName.Length == 0)
CapturedObjectName.Buffer = NULL;
/* Create the Object */ /* Create the Object */
Status = ObCreateObject(PreviousMode, Status = ObCreateObject(PreviousMode,
LpcPortObjectType, LpcPortObjectType,
@ -109,7 +112,7 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
InitializeListHead(&Port->LpcReplyChainHead); InitializeListHead(&Port->LpcReplyChainHead);
/* Check if we don't have a name */ /* Check if we don't have a name */
if (NoName) if (CapturedObjectName.Buffer == NULL)
{ {
/* Set up for an unconnected port */ /* Set up for an unconnected port */
Port->Flags = LPCP_UNCONNECTED_PORT; Port->Flags = LPCP_UNCONNECTED_PORT;
@ -187,7 +190,8 @@ LpcpCreatePort(OUT PHANDLE PortHandle,
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
ObCloseHandle(Handle, UserMode); /* An exception happened, close the opened handle */
ObCloseHandle(Handle, PreviousMode);
Status = _SEH2_GetExceptionCode(); Status = _SEH2_GetExceptionCode();
} }
_SEH2_END; _SEH2_END;

View file

@ -36,13 +36,22 @@ NtListenPort(IN HANDLE PortHandle,
NULL, NULL,
ConnectMessage); ConnectMessage);
/* Accept only LPC_CONNECTION_REQUEST requests */ _SEH2_TRY
if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
{ {
/* Break out */ /* Accept only LPC_CONNECTION_REQUEST requests */
break; if ((Status != STATUS_SUCCESS) ||
(LpcpGetMessageType(ConnectMessage) == LPC_CONNECTION_REQUEST))
{
/* Break out */
_SEH2_YIELD(break);
}
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
Status = _SEH2_GetExceptionCode();
_SEH2_YIELD(break);
}
_SEH2_END;
} }
/* Return status */ /* Return status */

View file

@ -136,28 +136,26 @@ NtImpersonateClientOfPort(IN HANDLE PortHandle,
PAGED_CODE(); PAGED_CODE();
/* Check the previous mode */ /* Check if the call comes from user mode */
PreviousMode = ExGetPreviousMode(); if (PreviousMode != KernelMode)
if (PreviousMode == KernelMode)
{
ClientId = ClientMessage->ClientId;
MessageId = ClientMessage->MessageId;
}
else
{ {
_SEH2_TRY _SEH2_TRY
{ {
ProbeForRead(ClientMessage, sizeof(*ClientMessage), sizeof(PVOID)); ProbeForRead(ClientMessage, sizeof(*ClientMessage), sizeof(PVOID));
ClientId = ClientMessage->ClientId; ClientId = ((volatile PORT_MESSAGE*)ClientMessage)->ClientId;
MessageId = ClientMessage->MessageId; MessageId = ((volatile PORT_MESSAGE*)ClientMessage)->MessageId;
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
DPRINT1("Got exception!\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode()); _SEH2_YIELD(return _SEH2_GetExceptionCode());
} }
_SEH2_END; _SEH2_END;
} }
else
{
ClientId = ClientMessage->ClientId;
MessageId = ClientMessage->MessageId;
}
/* Reference the port handle */ /* Reference the port handle */
Status = ObReferenceObjectByHandle(PortHandle, Status = ObReferenceObjectByHandle(PortHandle,

View file

@ -192,7 +192,7 @@ NtReplyPort(IN HANDLE PortHandle,
{ {
NTSTATUS Status; NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
// PORT_MESSAGE CapturedReplyMessage; PORT_MESSAGE CapturedReplyMessage;
PLPCP_PORT_OBJECT Port; PLPCP_PORT_OBJECT Port;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
PETHREAD Thread = PsGetCurrentThread(), WakeupThread; PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
@ -203,32 +203,35 @@ NtReplyPort(IN HANDLE PortHandle,
PortHandle, PortHandle,
ReplyMessage); ReplyMessage);
if (KeGetPreviousMode() == UserMode) /* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{ {
_SEH2_TRY _SEH2_TRY
{ {
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
/*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
ReplyMessage = &CapturedReplyMessage;*/
} }
_SEH2_EXCEPT(ExSystemExceptionFilter()) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
DPRINT1("SEH crash [1]\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode()); _SEH2_YIELD(return _SEH2_GetExceptionCode());
} }
_SEH2_END; _SEH2_END;
} }
else
{
CapturedReplyMessage = *ReplyMessage;
}
/* Validate its length */ /* Validate its length */
if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)ReplyMessage->u1.s1.TotalLength) (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
{ {
/* Fail */ /* Fail */
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
/* Make sure it has a valid ID */ /* Make sure it has a valid ID */
if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER; if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
/* Get the Port object */ /* Get the Port object */
Status = ObReferenceObjectByHandle(PortHandle, Status = ObReferenceObjectByHandle(PortHandle,
@ -240,9 +243,9 @@ NtReplyPort(IN HANDLE PortHandle,
if (!NT_SUCCESS(Status)) return Status; if (!NT_SUCCESS(Status)) return Status;
/* Validate its length in respect to the port object */ /* Validate its length in respect to the port object */
if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) || if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)ReplyMessage->u1.s1.TotalLength <= ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
(ULONG)ReplyMessage->u1.s1.DataLength)) (ULONG)CapturedReplyMessage.u1.s1.DataLength))
{ {
/* Too large, fail */ /* Too large, fail */
ObDereferenceObject(Port); ObDereferenceObject(Port);
@ -250,7 +253,7 @@ NtReplyPort(IN HANDLE PortHandle,
} }
/* Get the ETHREAD corresponding to it */ /* Get the ETHREAD corresponding to it */
Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId, Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
NULL, NULL,
&WakeupThread); &WakeupThread);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
@ -274,7 +277,7 @@ NtReplyPort(IN HANDLE PortHandle,
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */ /* Make sure this is the reply the thread is waiting for */
if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) || if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
((LpcpGetMessageFromThread(WakeupThread)) && ((LpcpGetMessageFromThread(WakeupThread)) &&
(LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request) (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> Request)
!= LPC_REQUEST))) != LPC_REQUEST)))
@ -290,7 +293,7 @@ NtReplyPort(IN HANDLE PortHandle,
_SEH2_TRY _SEH2_TRY
{ {
LpcpMoveMessage(&Message->Request, LpcpMoveMessage(&Message->Request,
ReplyMessage, &CapturedReplyMessage,
ReplyMessage + 1, ReplyMessage + 1,
LPC_REPLY, LPC_REPLY,
NULL); NULL);
@ -324,7 +327,7 @@ NtReplyPort(IN HANDLE PortHandle,
/* Check if this is the message the thread had received */ /* Check if this is the message the thread had received */
if ((Thread->LpcReceivedMsgIdValid) && if ((Thread->LpcReceivedMsgIdValid) &&
(Thread->LpcReceivedMessageId == ReplyMessage->MessageId)) (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
{ {
/* Clear this data */ /* Clear this data */
Thread->LpcReceivedMessageId = 0; Thread->LpcReceivedMessageId = 0;
@ -333,9 +336,9 @@ NtReplyPort(IN HANDLE PortHandle,
/* Free any data information */ /* Free any data information */
LpcpFreeDataInfoMessage(Port, LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId, CapturedReplyMessage.MessageId,
ReplyMessage->CallbackId, CapturedReplyMessage.CallbackId,
ReplyMessage->ClientId); CapturedReplyMessage.ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */ /* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
@ -362,7 +365,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{ {
NTSTATUS Status; NTSTATUS Status;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(), WaitMode = PreviousMode;
// PORT_MESSAGE CapturedReplyMessage; PORT_MESSAGE CapturedReplyMessage;
LARGE_INTEGER CapturedTimeout; LARGE_INTEGER CapturedTimeout;
PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL; PLPCP_PORT_OBJECT Port, ReceivePort, ConnectionPort = NULL;
PLPCP_MESSAGE Message; PLPCP_MESSAGE Message;
@ -378,30 +381,29 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ReceiveMessage, ReceiveMessage,
PortContext); PortContext);
if (KeGetPreviousMode() == UserMode) /* Check if the call comes from user mode */
if (PreviousMode != KernelMode)
{ {
_SEH2_TRY _SEH2_TRY
{ {
if (PortContext != NULL)
ProbeForWritePointer(PortContext);
if (ReplyMessage != NULL) if (ReplyMessage != NULL)
{ {
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG));
/*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage;
ReplyMessage = &CapturedReplyMessage;*/
} }
if (Timeout != NULL) if (Timeout != NULL)
{ {
ProbeForReadLargeInteger(Timeout); ProbeForReadLargeInteger(Timeout);
RtlCopyMemory(&CapturedTimeout, Timeout, sizeof(LARGE_INTEGER)); CapturedTimeout = *(volatile LARGE_INTEGER*)Timeout;
Timeout = &CapturedTimeout; Timeout = &CapturedTimeout;
} }
if (PortContext != NULL)
ProbeForWritePointer(PortContext);
} }
_SEH2_EXCEPT(ExSystemExceptionFilter()) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
DPRINT1("SEH crash [1]\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode()); _SEH2_YIELD(return _SEH2_GetExceptionCode());
} }
_SEH2_END; _SEH2_END;
@ -410,21 +412,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
{ {
/* If this is a system thread, then let it page out its stack */ /* If this is a system thread, then let it page out its stack */
if (Thread->SystemThread) WaitMode = UserMode; if (Thread->SystemThread) WaitMode = UserMode;
if (ReplyMessage != NULL)
CapturedReplyMessage = *ReplyMessage;
} }
/* Check if caller has a reply message */ /* Check if caller has a reply message */
if (ReplyMessage) if (ReplyMessage)
{ {
/* Validate its length */ /* Validate its length */
if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > if (((ULONG)CapturedReplyMessage.u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)ReplyMessage->u1.s1.TotalLength) (ULONG)CapturedReplyMessage.u1.s1.TotalLength)
{ {
/* Fail */ /* Fail */
return STATUS_INVALID_PARAMETER; return STATUS_INVALID_PARAMETER;
} }
/* Make sure it has a valid ID */ /* Make sure it has a valid ID */
if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER; if (!CapturedReplyMessage.MessageId) return STATUS_INVALID_PARAMETER;
} }
/* Get the Port object */ /* Get the Port object */
@ -440,9 +445,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
if (ReplyMessage) if (ReplyMessage)
{ {
/* Validate its length in respect to the port object */ /* Validate its length in respect to the port object */
if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) || if (((ULONG)CapturedReplyMessage.u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)ReplyMessage->u1.s1.TotalLength <= ((ULONG)CapturedReplyMessage.u1.s1.TotalLength <=
(ULONG)ReplyMessage->u1.s1.DataLength)) (ULONG)CapturedReplyMessage.u1.s1.DataLength))
{ {
/* Too large, fail */ /* Too large, fail */
ObDereferenceObject(Port); ObDereferenceObject(Port);
@ -490,7 +495,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
if (ReplyMessage) if (ReplyMessage)
{ {
/* Get the ETHREAD corresponding to it */ /* Get the ETHREAD corresponding to it */
Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId, Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId,
NULL, NULL,
&WakeupThread); &WakeupThread);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
@ -516,7 +521,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
KeAcquireGuardedMutex(&LpcpLock); KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */ /* Make sure this is the reply the thread is waiting for */
if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) || if ((WakeupThread->LpcReplyMessageId != CapturedReplyMessage.MessageId) ||
((LpcpGetMessageFromThread(WakeupThread)) && ((LpcpGetMessageFromThread(WakeupThread)) &&
(LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request) (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->Request)
!= LPC_REQUEST))) != LPC_REQUEST)))
@ -530,11 +535,24 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
} }
/* Copy the message */ /* Copy the message */
LpcpMoveMessage(&Message->Request, _SEH2_TRY
ReplyMessage, {
ReplyMessage + 1, LpcpMoveMessage(&Message->Request,
LPC_REPLY, &CapturedReplyMessage,
NULL); ReplyMessage + 1,
LPC_REPLY,
NULL);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Cleanup and return the exception code */
LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port);
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
/* Reference the thread while we use it */ /* Reference the thread while we use it */
ObReferenceObject(WakeupThread); ObReferenceObject(WakeupThread);
@ -555,7 +573,7 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Check if this is the message the thread had received */ /* Check if this is the message the thread had received */
if ((Thread->LpcReceivedMsgIdValid) && if ((Thread->LpcReceivedMsgIdValid) &&
(Thread->LpcReceivedMessageId == ReplyMessage->MessageId)) (Thread->LpcReceivedMessageId == CapturedReplyMessage.MessageId))
{ {
/* Clear this data */ /* Clear this data */
Thread->LpcReceivedMessageId = 0; Thread->LpcReceivedMessageId = 0;
@ -564,9 +582,9 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
/* Free any data information */ /* Free any data information */
LpcpFreeDataInfoMessage(Port, LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId, CapturedReplyMessage.MessageId,
ReplyMessage->CallbackId, CapturedReplyMessage.CallbackId,
ReplyMessage->ClientId); CapturedReplyMessage.ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */ /* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock); KeReleaseGuardedMutex(&LpcpLock);
@ -688,9 +706,8 @@ NtReplyWaitReceivePortEx(IN HANDLE PortHandle,
ASSERT(FALSE); ASSERT(FALSE);
} }
} }
_SEH2_EXCEPT(ExSystemExceptionFilter()) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
DPRINT1("SEH crash [2]\n");
Status = _SEH2_GetExceptionCode(); Status = _SEH2_GetExceptionCode();
} }
_SEH2_END; _SEH2_END;
@ -771,26 +788,24 @@ LpcpCopyRequestData(
PAGED_CODE(); PAGED_CODE();
/* Check the previous mode */ /* Check if the call comes from user mode */
PreviousMode = ExGetPreviousMode(); if (PreviousMode != KernelMode)
if (PreviousMode == KernelMode)
{
CapturedMessage = *Message;
}
else
{ {
_SEH2_TRY _SEH2_TRY
{ {
ProbeForRead(Message, sizeof(*Message), sizeof(PVOID)); ProbeForRead(Message, sizeof(*Message), sizeof(PVOID));
CapturedMessage = *Message; CapturedMessage = *(volatile PORT_MESSAGE*)Message;
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
DPRINT1("Got exception!\n");
_SEH2_YIELD(return _SEH2_GetExceptionCode()); _SEH2_YIELD(return _SEH2_GetExceptionCode());
} }
_SEH2_END; _SEH2_END;
} }
else
{
CapturedMessage = *Message;
}
/* Make sure there is any data to copy */ /* Make sure there is any data to copy */
if (CapturedMessage.u2.s2.DataInfoOffset == 0) if (CapturedMessage.u2.s2.DataInfoOffset == 0)

View file

@ -461,9 +461,8 @@ NtRequestPort(IN HANDLE PortHandle,
_SEH2_TRY _SEH2_TRY
{ {
/* Probe and capture the LpcRequest */ /* Probe and capture the LpcRequest */
ProbeForRead(LpcRequest, sizeof(PORT_MESSAGE), sizeof(ULONG)); ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG)); CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
CapturedLpcRequest = *LpcRequest;
} }
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{ {
@ -523,7 +522,7 @@ NtRequestPort(IN HANDLE PortHandle,
{ {
/* Copy it */ /* Copy it */
LpcpMoveMessage(&Message->Request, LpcpMoveMessage(&Message->Request,
LpcRequest, &CapturedLpcRequest,
LpcRequest + 1, LpcRequest + 1,
MessageType, MessageType,
&Thread->Cid); &Thread->Cid);
@ -725,10 +724,9 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle,
{ {
_SEH2_TRY _SEH2_TRY
{ {
/* Probe the full request message and copy the base structure */ /* Probe and capture the LpcRequest */
ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG)); ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG));
ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG)); CapturedLpcRequest = *(volatile PORT_MESSAGE*)LpcRequest;
CapturedLpcRequest = *LpcRequest;
/* Probe the reply message for write */ /* Probe the reply message for write */
ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG)); ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG));