diff --git a/reactos/ntoskrnl/include/internal/lpc.h b/reactos/ntoskrnl/include/internal/lpc.h index c6d3a14d781..e862a4f2476 100644 --- a/reactos/ntoskrnl/include/internal/lpc.h +++ b/reactos/ntoskrnl/include/internal/lpc.h @@ -63,6 +63,18 @@ #define LPCP_LOCK_HELD 1 #define LPCP_LOCK_RELEASE 2 + +typedef struct _LPCP_DATA_INFO +{ + ULONG NumberOfEntries; + struct + { + PVOID BaseAddress; + ULONG DataLength; + } Entries[1]; +} LPCP_DATA_INFO, *PLPCP_DATA_INFO; + + // // Internal Port Management // @@ -131,6 +143,13 @@ LpcInitSystem( VOID ); +BOOLEAN +NTAPI +LpcpValidateClientPort( + PETHREAD ClientThread, + PLPCP_PORT_OBJECT Port); + + // // Global data inside the Process Manager // diff --git a/reactos/ntoskrnl/include/internal/lpc_x.h b/reactos/ntoskrnl/include/internal/lpc_x.h index 816f979a409..0d04988305e 100644 --- a/reactos/ntoskrnl/include/internal/lpc_x.h +++ b/reactos/ntoskrnl/include/internal/lpc_x.h @@ -164,3 +164,10 @@ LpcpSetPortToThread(IN PETHREAD Thread, Thread->LpcWaitingOnPort = (PVOID)(((ULONG_PTR)Port) | LPCP_THREAD_FLAG_IS_PORT); } + +FORCEINLINE +PLPCP_DATA_INFO +LpcpGetDataInfoFromMessage(PPORT_MESSAGE Message) +{ + return (PLPCP_DATA_INFO)((PUCHAR)Message + Message->u2.s2.DataInfoOffset); +} diff --git a/reactos/ntoskrnl/include/internal/mm.h b/reactos/ntoskrnl/include/internal/mm.h index 4e9ddc0f52b..6a894941251 100644 --- a/reactos/ntoskrnl/include/internal/mm.h +++ b/reactos/ntoskrnl/include/internal/mm.h @@ -1798,3 +1798,17 @@ VOID NTAPI MmSetSessionLocaleId( _In_ LCID LocaleId); + + +/* virtual.c *****************************************************************/ + +NTSTATUS +NTAPI +MmCopyVirtualMemory(IN PEPROCESS SourceProcess, + IN PVOID SourceAddress, + IN PEPROCESS TargetProcess, + OUT PVOID TargetAddress, + IN SIZE_T BufferSize, + IN KPROCESSOR_MODE PreviousMode, + OUT PSIZE_T ReturnSize); + diff --git a/reactos/ntoskrnl/lpc/reply.c b/reactos/ntoskrnl/lpc/reply.c index 744c9105232..d07152f83af 100644 --- a/reactos/ntoskrnl/lpc/reply.c +++ b/reactos/ntoskrnl/lpc/reply.c @@ -90,6 +90,49 @@ LpcpSaveDataInfoMessage(IN PLPCP_PORT_OBJECT Port, if (!LockHeld) KeReleaseGuardedMutex(&LpcpLock); } +PLPCP_MESSAGE +NTAPI +LpcpFindDataInfoMessage( + IN PLPCP_PORT_OBJECT Port, + IN ULONG MessageId, + IN LPC_CLIENT_ID ClientId) +{ + PLPCP_MESSAGE Message; + PLIST_ENTRY ListEntry; + PAGED_CODE(); + + /* Check if the port we want is the connection port */ + if ((Port->Flags & LPCP_PORT_TYPE_MASK) > LPCP_UNCONNECTED_PORT) + { + /* Use it */ + Port = Port->ConnectionPort; + if (!Port) + { + /* Return NULL */ + return NULL; + } + } + + /* Loop all entries in the list */ + for (ListEntry = Port->LpcDataInfoChainHead.Flink; + ListEntry != &Port->LpcDataInfoChainHead; + ListEntry = ListEntry->Flink) + { + Message = CONTAINING_RECORD(ListEntry, LPCP_MESSAGE, Entry); + + /* Check if this is the desired message */ + if ((Message->Request.MessageId == MessageId) && + (Message->Request.ClientId.UniqueProcess == ClientId.UniqueProcess) && + (Message->Request.ClientId.UniqueThread == ClientId.UniqueThread)) + { + /* It is, return it */ + return Message; + } + } + + return NULL; +} + VOID NTAPI LpcpMoveMessage(IN PPORT_MESSAGE Destination, @@ -132,7 +175,7 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination, /* Copy the Message Data */ RtlCopyMemory(Destination + 1, Data, - ((Destination->u1.Length & 0xFFFF) + 3) &~3); + ALIGN_UP_BY(Destination->u1.s1.DataLength, sizeof(ULONG))); } /* PUBLIC FUNCTIONS **********************************************************/ @@ -710,6 +753,199 @@ NtReplyWaitReplyPort(IN HANDLE PortHandle, return STATUS_NOT_IMPLEMENTED; } +NTSTATUS +NTAPI +LpcpCopyRequestData( + IN BOOLEAN Write, + IN HANDLE PortHandle, + IN PPORT_MESSAGE Message, + IN ULONG Index, + IN PVOID Buffer, + IN ULONG BufferLength, + OUT PULONG Returnlength) +{ + KPROCESSOR_MODE PreviousMode; + PORT_MESSAGE CapturedMessage; + PLPCP_PORT_OBJECT Port = NULL; + PETHREAD ClientThread = NULL; + ULONG LocalReturnlength; + PLPCP_MESSAGE InfoMessage; + PLPCP_DATA_INFO DataInfo; + PVOID DataInfoBaseAddress; + NTSTATUS Status; + PAGED_CODE(); + + /* Check the previous mode */ + PreviousMode = ExGetPreviousMode(); + if (PreviousMode == KernelMode) + { + CapturedMessage = *Message; + } + else + { + _SEH2_TRY + { + ProbeForRead(Message, sizeof(*Message), sizeof(PVOID)); + CapturedMessage = *Message; + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + DPRINT1("Got exception!\n"); + return _SEH2_GetExceptionCode(); + } + _SEH2_END; + } + + /* Make sure there is any data to copy */ + if (CapturedMessage.u2.s2.DataInfoOffset == 0) + { + return STATUS_INVALID_PARAMETER; + } + + /* Reference the port handle */ + Status = ObReferenceObjectByHandle(PortHandle, + PORT_ALL_ACCESS, + LpcPortObjectType, + PreviousMode, + (PVOID*)&Port, + NULL); + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to reference port handle: 0x%ls\n", Status); + return Status; + } + + /* Look up the client thread */ + Status = PsLookupProcessThreadByCid(&CapturedMessage.ClientId, + NULL, + &ClientThread); + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to lookup client thread for [0x%lx:0x%lx]: 0x%ls\n", + CapturedMessage.ClientId.UniqueProcess, + CapturedMessage.ClientId.UniqueThread, Status); + goto Cleanup; + } + + /* Acquire the global LPC lock */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Check for message id mismatch */ + if ((ClientThread->LpcReplyMessageId != CapturedMessage.MessageId) || + (CapturedMessage.MessageId == 0)) + { + DPRINT1("LpcReplyMessageId mismatch: 0x%lx/0x%lx.\n", + ClientThread->LpcReplyMessageId, CapturedMessage.MessageId); + Status = STATUS_REPLY_MESSAGE_MISMATCH; + goto CleanupWithLock; + } + + /* Validate the port */ + if (!LpcpValidateClientPort(ClientThread, Port)) + { + DPRINT1("LpcpValidateClientPort failed\n"); + Status = STATUS_REPLY_MESSAGE_MISMATCH; + goto CleanupWithLock; + } + + /* Find the message with the data */ + InfoMessage = LpcpFindDataInfoMessage(Port, + CapturedMessage.MessageId, + CapturedMessage.ClientId); + if (InfoMessage == NULL) + { + DPRINT1("LpcpFindDataInfoMessage failed\n"); + Status = STATUS_INVALID_PARAMETER; + goto CleanupWithLock; + } + + /* Get the data info */ + DataInfo = LpcpGetDataInfoFromMessage(&InfoMessage->Request); + + /* Check if the index is within bounds */ + if (Index >= DataInfo->NumberOfEntries) + { + DPRINT1("Message data index %lu out of bounds (%lu in msg)\n", + Index, DataInfo->NumberOfEntries); + Status = STATUS_INVALID_PARAMETER; + goto CleanupWithLock; + } + + /* Check if the caller wants to read/write more data than expected */ + if (BufferLength > DataInfo->Entries[Index].DataLength) + { + DPRINT1("Trying to read more data (%lu) than available (%lu)\n", + BufferLength, DataInfo->Entries[Index].DataLength); + Status = STATUS_INVALID_PARAMETER; + goto CleanupWithLock; + } + + /* Get the data pointer */ + DataInfoBaseAddress = DataInfo->Entries[Index].BaseAddress; + + /* Release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + + if (Write) + { + /* Copy data from the caller to the message sender */ + Status = MmCopyVirtualMemory(PsGetCurrentProcess(), + Buffer, + ClientThread->ThreadsProcess, + DataInfoBaseAddress, + BufferLength, + PreviousMode, + &LocalReturnlength); + } + else + { + /* Copy data from the message sender to the caller */ + Status = MmCopyVirtualMemory(ClientThread->ThreadsProcess, + DataInfoBaseAddress, + PsGetCurrentProcess(), + Buffer, + BufferLength, + PreviousMode, + &LocalReturnlength); + } + + if (!NT_SUCCESS(Status)) + { + DPRINT1("MmCopyVirtualMemory failed: 0x%ls\n", Status); + goto Cleanup; + } + + /* Check if the caller asked to return the copied length */ + if (Returnlength != NULL) + { + _SEH2_TRY + { + *Returnlength = LocalReturnlength; + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Ignore */ + DPRINT1("Exception writing Returnlength, ignoring\n"); + } + _SEH2_END; + } + +Cleanup: + + if (ClientThread != NULL) + ObDereferenceObject(ClientThread); + + ObDereferenceObject(Port); + + return Status; + +CleanupWithLock: + + /* Release the lock */ + KeReleaseGuardedMutex(&LpcpLock); + goto Cleanup; +} + /* * @unimplemented */ @@ -720,10 +956,16 @@ NtReadRequestData(IN HANDLE PortHandle, IN ULONG Index, IN PVOID Buffer, IN ULONG BufferLength, - OUT PULONG Returnlength) + OUT PULONG ReturnLength) { - UNIMPLEMENTED; - return STATUS_NOT_IMPLEMENTED; + /* Call the internal function */ + return LpcpCopyRequestData(FALSE, + PortHandle, + Message, + Index, + Buffer, + BufferLength, + ReturnLength); } /* @@ -738,8 +980,14 @@ NtWriteRequestData(IN HANDLE PortHandle, IN ULONG BufferLength, OUT PULONG ReturnLength) { - UNIMPLEMENTED; - return STATUS_NOT_IMPLEMENTED; + /* Call the internal function */ + return LpcpCopyRequestData(TRUE, + PortHandle, + Message, + Index, + Buffer, + BufferLength, + ReturnLength); } /* EOF */ diff --git a/reactos/ntoskrnl/lpc/send.c b/reactos/ntoskrnl/lpc/send.c index 7a3cd451929..a9aad6afb2d 100644 --- a/reactos/ntoskrnl/lpc/send.c +++ b/reactos/ntoskrnl/lpc/send.c @@ -207,34 +207,34 @@ LpcRequestWaitReplyPort(IN PVOID PortObject, { /* No type */ case 0: - + /* Assume LPC request */ MessageType = LPC_REQUEST; break; - + /* LPC request callback */ case LPC_REQUEST: - + /* This is a callback */ Callback = TRUE; break; - + /* Anything else */ case LPC_CLIENT_DIED: case LPC_PORT_CLOSED: case LPC_EXCEPTION: case LPC_DEBUG_EVENT: case LPC_ERROR_EVENT: - + /* Nothing to do */ break; - + default: - + /* Invalid message type */ return STATUS_INVALID_PARAMETER; } - + /* Set the request type */ LpcRequest->u2.s2.Type = MessageType; @@ -401,10 +401,10 @@ LpcRequestWaitReplyPort(IN PVOID PortObject, (&Message->Request) + 1, 0, NULL); - + /* Acquire the lock */ KeAcquireGuardedMutex(&LpcpLock); - + /* Check if we replied to a thread */ if (Message->RepliedToThread) { @@ -633,6 +633,47 @@ NtRequestPort(IN HANDLE PortHandle, return Status; } +NTSTATUS +NTAPI +LpcpVerifyMessageDataInfo( + _In_ PPORT_MESSAGE Message, + _Out_ PULONG NumberOfDataEntries) +{ + PLPCP_DATA_INFO DataInfo; + PUCHAR EndOfEntries; + + /* Check if we have no data info at all */ + if (Message->u2.s2.DataInfoOffset == 0) + { + *NumberOfDataEntries = 0; + return STATUS_SUCCESS; + } + + /* Make sure the data info structure is within the message */ + if (((ULONG)Message->u1.s1.TotalLength < + sizeof(PORT_MESSAGE) + sizeof(LPCP_DATA_INFO)) || + ((ULONG)Message->u2.s2.DataInfoOffset < sizeof(PORT_MESSAGE)) || + ((ULONG)Message->u2.s2.DataInfoOffset > + ((ULONG)Message->u1.s1.TotalLength - sizeof(LPCP_DATA_INFO)))) + { + return STATUS_INVALID_PARAMETER; + } + + /* Get a pointer to the data info */ + DataInfo = LpcpGetDataInfoFromMessage(Message); + + /* Make sure the full data info with all entries is within the message */ + EndOfEntries = (PUCHAR)&DataInfo->Entries[DataInfo->NumberOfEntries]; + if ((EndOfEntries > ((PUCHAR)Message + (ULONG)Message->u1.s1.TotalLength)) || + (EndOfEntries < (PUCHAR)Message)) + { + return STATUS_INVALID_PARAMETER; + } + + *NumberOfDataEntries = DataInfo->NumberOfEntries; + return STATUS_SUCCESS; +} + /* * @implemented */ @@ -642,6 +683,8 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, IN PPORT_MESSAGE LpcRequest, IN OUT PPORT_MESSAGE LpcReply) { + PORT_MESSAGE LocalLpcRequest; + ULONG NumberOfDataEntries; PLPCP_PORT_OBJECT Port, QueuePort, ReplyPort, ConnectionPort = NULL; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); NTSTATUS Status; @@ -650,6 +693,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, BOOLEAN Callback; PKSEMAPHORE Semaphore; ULONG MessageType; + PLPCP_DATA_INFO DataInfo; PAGED_CODE(); LPCTRACE(LPC_SEND_DEBUG, "Handle: %p. Messages: %p/%p. Type: %lx\n", @@ -661,32 +705,78 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* Check if the thread is dying */ if (Thread->LpcExitThreadCalled) return STATUS_THREAD_IS_TERMINATING; + /* Check for user mode access */ + if (PreviousMode != KernelMode) + { + _SEH2_TRY + { + /* Probe the full request message and copy the base structure */ + ProbeForRead(LpcRequest, sizeof(*LpcRequest), sizeof(ULONG)); + ProbeForRead(LpcRequest, LpcRequest->u1.s1.TotalLength, sizeof(ULONG)); + LocalLpcRequest = *LpcRequest; + + /* Probe the reply message for write */ + ProbeForWrite(LpcReply, sizeof(*LpcReply), sizeof(ULONG)); + + /* Make sure the data entries in the request message are valid */ + Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries); + if (!NT_SUCCESS(Status)) + { + DPRINT1("LpcpVerifyMessageDataInfo failed\n"); + return Status; + } + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + DPRINT1("Got exception\n"); + return _SEH2_GetExceptionCode(); + } + _SEH2_END; + } + else + { + LocalLpcRequest = *LpcRequest; + Status = LpcpVerifyMessageDataInfo(LpcRequest, &NumberOfDataEntries); + if (!NT_SUCCESS(Status)) + { + DPRINT1("LpcpVerifyMessageDataInfo failed\n"); + return Status; + } + } + /* Check if this is an LPC Request */ - if (LpcpGetMessageType(LpcRequest) == LPC_REQUEST) + if (LpcpGetMessageType(&LocalLpcRequest) == LPC_REQUEST) { /* Then it's a callback */ Callback = TRUE; } - else if (LpcpGetMessageType(LpcRequest)) + else if (LpcpGetMessageType(&LocalLpcRequest)) { /* This is a not kernel-mode message */ + DPRINT1("Not a kernel-mode message!\n"); return STATUS_INVALID_PARAMETER; } else { /* This is a kernel-mode message without a callback */ - LpcRequest->u2.s2.Type |= LPC_REQUEST; + LocalLpcRequest.u2.s2.Type |= LPC_REQUEST; Callback = FALSE; } /* Get the message type */ - MessageType = LpcRequest->u2.s2.Type; + MessageType = LocalLpcRequest.u2.s2.Type; + + /* Due to the above probe, we know that TotalLength is positive */ + NT_ASSERT(LocalLpcRequest.u1.s1.TotalLength >= 0); /* Validate the length */ - if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > - (ULONG)LpcRequest->u1.s1.TotalLength) + if ((((ULONG)(USHORT)LocalLpcRequest.u1.s1.DataLength + sizeof(PORT_MESSAGE)) > + (ULONG)LocalLpcRequest.u1.s1.TotalLength)) { /* Fail */ + DPRINT1("Invalid message length: %u, %u\n", + LocalLpcRequest.u1.s1.DataLength, + LocalLpcRequest.u1.s1.TotalLength); return STATUS_INVALID_PARAMETER; } @@ -700,10 +790,13 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, if (!NT_SUCCESS(Status)) return Status; /* Validate the message length */ - if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) || - ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength)) + if (((ULONG)LocalLpcRequest.u1.s1.TotalLength > Port->MaxMessageLength) || + ((ULONG)LocalLpcRequest.u1.s1.TotalLength <= (ULONG)LocalLpcRequest.u1.s1.DataLength)) { /* Fail */ + DPRINT1("Invalid message length: %u, %u\n", + LocalLpcRequest.u1.s1.DataLength, + LocalLpcRequest.u1.s1.TotalLength); ObDereferenceObject(Port); return STATUS_PORT_MESSAGE_TOO_LONG; } @@ -713,6 +806,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, if (!Message) { /* Fail if we couldn't allocate a message */ + DPRINT1("Failed to allocate a message!\n"); ObDereferenceObject(Port); return STATUS_NO_MEMORY; } @@ -729,6 +823,22 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* No callback, just copy the message */ _SEH2_TRY { + /* Check if we have data info entries */ + if (LpcRequest->u2.s2.DataInfoOffset != 0) + { + /* Get the data info and check if the number of entries matches + what we expect */ + DataInfo = LpcpGetDataInfoFromMessage(LpcRequest); + if (DataInfo->NumberOfEntries != NumberOfDataEntries) + { + LpcpFreeToPortZone(Message, 0); + ObDereferenceObject(Port); + DPRINT1("NumberOfEntries has changed: %u, %u\n", + DataInfo->NumberOfEntries, NumberOfDataEntries); + return STATUS_INVALID_PARAMETER; + } + } + /* Copy it */ LpcpMoveMessage(&Message->Request, LpcRequest, @@ -739,6 +849,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* Fail */ + DPRINT1("Got exception!\n"); LpcpFreeToPortZone(Message, 0); ObDereferenceObject(Port); _SEH2_YIELD(return _SEH2_GetExceptionCode()); @@ -759,6 +870,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, if (!QueuePort) { /* We have no connected port, fail */ + DPRINT1("No connected port\n"); LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE); ObDereferenceObject(Port); return STATUS_PORT_DISCONNECTED; @@ -767,28 +879,21 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, /* This will be the rundown port */ ReplyPort = QueuePort; - /* Check if this is a communication port */ + /* Check if this is a client port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT) { - /* Copy the port context and use the connection port */ + /* Copy the port context */ Message->PortContext = QueuePort->PortContext; - ConnectionPort = QueuePort = Port->ConnectionPort; - if (!ConnectionPort) - { - /* Fail */ - LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE); - ObDereferenceObject(Port); - return STATUS_PORT_DISCONNECTED; - } } - else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != - LPCP_COMMUNICATION_PORT) + + if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT) { /* Use the connection port for anything but communication ports */ ConnectionPort = QueuePort = Port->ConnectionPort; if (!ConnectionPort) { /* Fail */ + DPRINT1("No connection port\n"); LpcpFreeToPortZone(Message, LPCP_LOCK_HELD | LPCP_LOCK_RELEASE); ObDereferenceObject(Port); return STATUS_PORT_DISCONNECTED; @@ -883,6 +988,7 @@ NtRequestWaitReplyPort(IN HANDLE PortHandle, } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { + DPRINT1("Got exception!\n"); Status = _SEH2_GetExceptionCode(); } _SEH2_END;