diff --git a/reactos/ntoskrnl/lpc/reply.c b/reactos/ntoskrnl/lpc/reply.c index 0233ea10cc3..377313e7e55 100644 --- a/reactos/ntoskrnl/lpc/reply.c +++ b/reactos/ntoskrnl/lpc/reply.c @@ -141,10 +141,172 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination, NTSTATUS NTAPI NtReplyPort(IN HANDLE PortHandle, - IN PPORT_MESSAGE LpcReply) + IN PPORT_MESSAGE ReplyMessage) { - UNIMPLEMENTED; - return STATUS_NOT_IMPLEMENTED; + PLPCP_PORT_OBJECT Port, ConnectionPort = NULL; + KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); + NTSTATUS Status; + PLPCP_MESSAGE Message; + PETHREAD Thread = PsGetCurrentThread(), WakeupThread; + //PORT_MESSAGE CapturedReplyMessage; + + PAGED_CODE(); + LPCTRACE(LPC_REPLY_DEBUG, + "Handle: %lx. Message: %p.\n", + PortHandle, + ReplyMessage); + + if (KeGetPreviousMode() == UserMode) + { + _SEH2_TRY + { + if (ReplyMessage != NULL) + { + ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG)); + /*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE)); + ReplyMessage = &CapturedReplyMessage;*/ + } + } + _SEH2_EXCEPT(ExSystemExceptionFilter()) + { + DPRINT1("SEH crash [1]\n"); + DbgBreakPoint(); + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; + } + + /* Validate its length */ + if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > + (ULONG)ReplyMessage->u1.s1.TotalLength) + { + /* Fail */ + return STATUS_INVALID_PARAMETER; + } + + /* Make sure it has a valid ID */ + if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER; + + /* Get the Port object */ + Status = ObReferenceObjectByHandle(PortHandle, + 0, + LpcPortObjectType, + PreviousMode, + (PVOID*)&Port, + NULL); + if (!NT_SUCCESS(Status)) return Status; + + /* Validate its length in respect to the port object */ + if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) || + ((ULONG)ReplyMessage->u1.s1.TotalLength <= + (ULONG)ReplyMessage->u1.s1.DataLength)) + { + /* Too large, fail */ + ObDereferenceObject(Port); + return STATUS_PORT_MESSAGE_TOO_LONG; + } + + /* Get the ETHREAD corresponding to it */ + Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId, + NULL, + &WakeupThread); + if (!NT_SUCCESS(Status)) + { + /* No thread found, fail */ + ObDereferenceObject(Port); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + return Status; + } + + /* Allocate a message from the port zone */ + Message = LpcpAllocateFromPortZone(); + if (!Message) + { + /* Fail if we couldn't allocate a message */ + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + ObDereferenceObject(WakeupThread); + ObDereferenceObject(Port); + return STATUS_NO_MEMORY; + } + + /* Keep the lock acquired */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Make sure this is the reply the thread is waiting for */ + if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) || + ((LpcpGetMessageFromThread(WakeupThread)) && + (LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)-> + Request) != LPC_REQUEST))) + { + /* It isn't, fail */ + LpcpFreeToPortZone(Message, 3); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + ObDereferenceObject(WakeupThread); + ObDereferenceObject(Port); + return STATUS_REPLY_MESSAGE_MISMATCH; + } + + /* Copy the message */ + _SEH2_TRY + { + LpcpMoveMessage(&Message->Request, + ReplyMessage, + ReplyMessage + 1, + LPC_REPLY, + NULL); + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(WakeupThread); + ObDereferenceObject(Port); + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; + + /* Reference the thread while we use it */ + ObReferenceObject(WakeupThread); + Message->RepliedToThread = WakeupThread; + + /* Set this as the reply message */ + WakeupThread->LpcReplyMessageId = 0; + WakeupThread->LpcReplyMessage = (PVOID)Message; + + /* Check if we have messages on the reply chain */ + if (!(WakeupThread->LpcExitThreadCalled) && + !(IsListEmpty(&WakeupThread->LpcReplyChain))) + { + /* Remove us from it and reinitialize it */ + RemoveEntryList(&WakeupThread->LpcReplyChain); + InitializeListHead(&WakeupThread->LpcReplyChain); + } + + /* Check if this is the message the thread had received */ + if ((Thread->LpcReceivedMsgIdValid) && + (Thread->LpcReceivedMessageId == ReplyMessage->MessageId)) + { + /* Clear this data */ + Thread->LpcReceivedMessageId = 0; + Thread->LpcReceivedMsgIdValid = FALSE; + } + + /* Free any data information */ + LpcpFreeDataInfoMessage(Port, + ReplyMessage->MessageId, + ReplyMessage->CallbackId, + ReplyMessage->ClientId); + + /* Release the lock and release the LPC semaphore to wake up waiters */ + KeReleaseGuardedMutex(&LpcpLock); + LpcpCompleteWait(&WakeupThread->LpcReplySemaphore); + + /* Now we can let go of the thread */ + ObDereferenceObject(WakeupThread); + + /* Dereference port object */ + ObDereferenceObject(Port); + return Status; } /* diff --git a/reactos/ntoskrnl/lpc/send.c b/reactos/ntoskrnl/lpc/send.c index dad5a35c865..6980ee445f7 100644 --- a/reactos/ntoskrnl/lpc/send.c +++ b/reactos/ntoskrnl/lpc/send.c @@ -442,10 +442,195 @@ LpcRequestWaitReplyPort(IN PVOID PortObject, NTSTATUS NTAPI NtRequestPort(IN HANDLE PortHandle, - IN PPORT_MESSAGE LpcMessage) + IN PPORT_MESSAGE LpcRequest) { - UNIMPLEMENTED; - return STATUS_NOT_IMPLEMENTED; + PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL; + KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); + NTSTATUS Status; + PLPCP_MESSAGE Message; + PETHREAD Thread = PsGetCurrentThread(); + + PKSEMAPHORE Semaphore; + ULONG MessageType; + PAGED_CODE(); + LPCTRACE(LPC_SEND_DEBUG, + "Handle: %lx. Message: %p. Type: %lx\n", + PortHandle, + LpcRequest, + LpcpGetMessageType(LpcRequest)); + + /* Get the message type */ + MessageType = LpcRequest->u2.s2.Type | LPC_DATAGRAM; + + /* Can't have data information on this type of call */ + if (LpcRequest->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER; + + /* Validate the length */ + if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) > + (ULONG)LpcRequest->u1.s1.TotalLength) + { + /* Fail */ + return STATUS_INVALID_PARAMETER; + } + + /* Reference the object */ + Status = ObReferenceObjectByHandle(PortHandle, + 0, + LpcPortObjectType, + PreviousMode, + (PVOID*)&Port, + NULL); + if (!NT_SUCCESS(Status)) return Status; + + /* Validate the message length */ + if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) || + ((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength)) + { + /* Fail */ + ObDereferenceObject(Port); + return STATUS_PORT_MESSAGE_TOO_LONG; + } + + /* Allocate a message from the port zone */ + Message = LpcpAllocateFromPortZone(); + if (!Message) + { + /* Fail if we couldn't allocate a message */ + ObDereferenceObject(Port); + return STATUS_NO_MEMORY; + } + + /* No callback, just copy the message */ + _SEH2_TRY + { + /* Copy it */ + LpcpMoveMessage(&Message->Request, + LpcRequest, + LpcRequest + 1, + MessageType, + &Thread->Cid); + } + _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) + { + /* Fail */ + LpcpFreeToPortZone(Message, 0); + ObDereferenceObject(Port); + _SEH2_YIELD(return _SEH2_GetExceptionCode()); + } + _SEH2_END; + + /* Acquire the LPC lock */ + KeAcquireGuardedMutex(&LpcpLock); + + /* Right now clear the port context */ + Message->PortContext = NULL; + + /* Check if this is a not connection port */ + if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT) + { + /* We want the connected port */ + QueuePort = Port->ConnectedPort; + if (!QueuePort) + { + /* We have no connected port, fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } + + /* Check if this is a communication port */ + if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT) + { + /* Copy the port context and use the connection port */ + Message->PortContext = QueuePort->PortContext; + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } + } + else if ((Port->Flags & LPCP_PORT_TYPE_MASK) != + LPCP_COMMUNICATION_PORT) + { + /* Use the connection port for anything but communication ports */ + ConnectionPort = QueuePort = Port->ConnectionPort; + if (!ConnectionPort) + { + /* Fail */ + LpcpFreeToPortZone(Message, 3); + ObDereferenceObject(Port); + return STATUS_PORT_DISCONNECTED; + } + } + + /* Reference the connection port if it exists */ + if (ConnectionPort) ObReferenceObject(ConnectionPort); + } + else + { + /* Otherwise, for a connection port, use the same port object */ + QueuePort = Port; + } + + /* Reference QueuePort if we have it */ + if (QueuePort && ObReferenceObjectSafe(QueuePort)) + { + /* Set sender's port */ + Message->SenderPort = Port; + + /* Generate the Message ID and set it */ + Message->Request.MessageId = LpcpNextMessageId++; + if (!LpcpNextMessageId) LpcpNextMessageId = 1; + Message->Request.CallbackId = 0; + + /* No Message ID for the thread */ + PsGetCurrentThread()->LpcReplyMessageId = 0; + + /* Insert the message in our chain */ + InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry); + + /* Release the lock and get the semaphore we'll use later */ + KeEnterCriticalRegion(); + KeReleaseGuardedMutex(&LpcpLock); + + /* Now release the semaphore */ + Semaphore = QueuePort->MsgQueue.Semaphore; + LpcpCompleteWait(Semaphore); + + /* If this is a waitable port, wake it up */ + if (QueuePort->Flags & LPCP_WAITABLE_PORT) + { + /* Wake it */ + KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE); + } + + KeLeaveCriticalRegion(); + + /* Dereference objects */ + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + ObDereferenceObject(QueuePort); + ObDereferenceObject(Port); + LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message); + return STATUS_SUCCESS; + } + + Status = STATUS_PORT_DISCONNECTED; + + /* All done with a failure*/ + LPCTRACE(LPC_SEND_DEBUG, + "Port: %p. Status: %p\n", + Port, + Status); + + /* The wait failed, free the message */ + if (Message) LpcpFreeToPortZone(Message, 3); + + ObDereferenceObject(Port); + if (ConnectionPort) ObDereferenceObject(ConnectionPort); + return Status; } /*