/* * PROJECT: ReactOS Kernel * LICENSE: GPL - See COPYING in the top level directory * FILE: ntoskrnl/lpc/complete.c * PURPOSE: Local Procedure Call: Connection Completion * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org) */ /* INCLUDES ******************************************************************/ #include #define NDEBUG #include /* PRIVATE FUNCTIONS *********************************************************/ VOID NTAPI LpcpPrepareToWakeClient(IN PETHREAD Thread) { PAGED_CODE(); /* Make sure the thread isn't dying and it has a valid chain */ if (!(Thread->LpcExitThreadCalled) && !(IsListEmpty(&Thread->LpcReplyChain))) { /* Remove it from the list and reinitialize it */ RemoveEntryList(&Thread->LpcReplyChain); InitializeListHead(&Thread->LpcReplyChain); } } /* PUBLIC FUNCTIONS **********************************************************/ /* * @implemented */ NTSTATUS NTAPI NtAcceptConnectPort(OUT PHANDLE PortHandle, IN PVOID PortContext OPTIONAL, IN PPORT_MESSAGE ReplyMessage, IN BOOLEAN AcceptConnection, IN OUT PPORT_VIEW ServerView OPTIONAL, OUT PREMOTE_PORT_VIEW ClientView OPTIONAL) { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); PORT_VIEW CapturedServerView; PORT_MESSAGE CapturedReplyMessage; ULONG ConnectionInfoLength; PLPCP_PORT_OBJECT ConnectionPort, ServerPort, ClientPort; PLPCP_CONNECTION_MESSAGE ConnectMessage; PLPCP_MESSAGE Message; PVOID ClientSectionToMap = NULL; HANDLE Handle; PEPROCESS ClientProcess; PETHREAD ClientThread; LARGE_INTEGER SectionOffset; PAGED_CODE(); LPCTRACE(LPC_COMPLETE_DEBUG, "Context: %p. Message: %p. Accept: %lx. Views: %p/%p\n", PortContext, ReplyMessage, AcceptConnection, ClientView, ServerView); /* Check if the call comes from user mode */ if (PreviousMode != KernelMode) { _SEH2_TRY { /* Probe the PortHandle */ ProbeForWriteHandle(PortHandle); /* Probe the basic ReplyMessage structure */ ProbeForRead(ReplyMessage, sizeof(*ReplyMessage), sizeof(ULONG)); CapturedReplyMessage = *(volatile PORT_MESSAGE*)ReplyMessage; ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength; /* Probe the connection info */ ProbeForRead(ReplyMessage + 1, ConnectionInfoLength, 1); /* The following parameters are optional */ /* Capture the server view */ if (ServerView) { ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG)); CapturedServerView = *(volatile PORT_VIEW*)ServerView; /* Validate the size of the server view */ if (CapturedServerView.Length != sizeof(CapturedServerView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); } } /* Capture the client view */ if (ClientView) { ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG)); /* Validate the size of the client view */ if (((volatile REMOTE_PORT_VIEW*)ClientView)->Length != sizeof(*ClientView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); } } } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* There was an exception, return the exception code */ _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } else { CapturedReplyMessage = *ReplyMessage; ConnectionInfoLength = CapturedReplyMessage.u1.s1.DataLength; /* Capture the server view */ if (ServerView) { /* Validate the size of the server view */ if (ServerView->Length != sizeof(*ServerView)) { /* Invalid size */ return STATUS_INVALID_PARAMETER; } CapturedServerView = *ServerView; } /* Capture the client view */ if (ClientView) { /* Validate the size of the client view */ if (ClientView->Length != sizeof(*ClientView)) { /* Invalid size */ return STATUS_INVALID_PARAMETER; } } } /* Get the client process and thread */ Status = PsLookupProcessThreadByCid(&CapturedReplyMessage.ClientId, &ClientProcess, &ClientThread); if (!NT_SUCCESS(Status)) return Status; /* Acquire the LPC Lock */ KeAcquireGuardedMutex(&LpcpLock); /* Make sure that the client wants a reply, and this is the right one */ if (!(LpcpGetMessageFromThread(ClientThread)) || !(CapturedReplyMessage.MessageId) || (ClientThread->LpcReplyMessageId != CapturedReplyMessage.MessageId)) { /* Not the reply asked for, or no reply wanted, fail */ KeReleaseGuardedMutex(&LpcpLock); ObDereferenceObject(ClientProcess); ObDereferenceObject(ClientThread); return STATUS_REPLY_MESSAGE_MISMATCH; } /* Now get the message and connection message */ Message = LpcpGetMessageFromThread(ClientThread); ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); /* Get the client and connection port as well */ ClientPort = ConnectMessage->ClientPort; ConnectionPort = ClientPort->ConnectionPort; /* Make sure that the reply is being sent to the proper server process */ if (ConnectionPort->ServerProcess != PsGetCurrentProcess()) { /* It's not, so fail */ KeReleaseGuardedMutex(&LpcpLock); ObDereferenceObject(ClientProcess); ObDereferenceObject(ClientThread); return STATUS_REPLY_MESSAGE_MISMATCH; } /* At this point, don't let other accept attempts happen */ ClientThread->LpcReplyMessage = NULL; ClientThread->LpcReplyMessageId = 0; /* Clear the client port for now as well, then release the lock */ ConnectMessage->ClientPort = NULL; KeReleaseGuardedMutex(&LpcpLock); /* Check the connection information length */ if (ConnectionInfoLength > ConnectionPort->MaxConnectionInfoLength) { /* Normalize it since it's too large */ ConnectionInfoLength = ConnectionPort->MaxConnectionInfoLength; } /* Set the sizes of our reply message */ Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength + sizeof(LPCP_CONNECTION_MESSAGE); Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) + Message->Request.u1.s1.DataLength; /* Setup the reply message */ Message->Request.u2.s2.Type = LPC_REPLY; Message->Request.u2.s2.DataInfoOffset = 0; Message->Request.ClientId = CapturedReplyMessage.ClientId; Message->Request.MessageId = CapturedReplyMessage.MessageId; Message->Request.ClientViewSize = 0; _SEH2_TRY { RtlCopyMemory(ConnectMessage + 1, ReplyMessage + 1, ConnectionInfoLength); } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { Status = _SEH2_GetExceptionCode(); _SEH2_YIELD(goto Cleanup); } _SEH2_END; /* At this point, if the caller refused the connection, go to cleanup */ if (!AcceptConnection) { DPRINT1("LPC connection was refused\n"); goto Cleanup; } /* Otherwise, create the actual port */ Status = ObCreateObject(PreviousMode, LpcPortObjectType, NULL, PreviousMode, NULL, sizeof(LPCP_PORT_OBJECT), 0, 0, (PVOID*)&ServerPort); if (!NT_SUCCESS(Status)) goto Cleanup; /* Set it up */ RtlZeroMemory(ServerPort, sizeof(LPCP_PORT_OBJECT)); ServerPort->PortContext = PortContext; ServerPort->Flags = LPCP_COMMUNICATION_PORT; ServerPort->MaxMessageLength = ConnectionPort->MaxMessageLength; InitializeListHead(&ServerPort->LpcReplyChainHead); InitializeListHead(&ServerPort->LpcDataInfoChainHead); /* Reference the connection port until we're fully setup */ ObReferenceObject(ConnectionPort); /* Link the ports together */ ServerPort->ConnectionPort = ConnectionPort; ServerPort->ConnectedPort = ClientPort; ClientPort->ConnectedPort = ServerPort; /* Also set the creator CID */ ServerPort->Creator = PsGetCurrentThread()->Cid; ClientPort->Creator = Message->Request.ClientId; /* Get the section associated and then clear it, while inside the lock */ KeAcquireGuardedMutex(&LpcpLock); ClientSectionToMap = ConnectMessage->SectionToMap; ConnectMessage->SectionToMap = NULL; KeReleaseGuardedMutex(&LpcpLock); /* Now check if there's a client section */ if (ClientSectionToMap) { /* Setup the offset */ SectionOffset.QuadPart = ConnectMessage->ClientView.SectionOffset; /* Map the section */ Status = MmMapViewOfSection(ClientSectionToMap, PsGetCurrentProcess(), &ServerPort->ClientSectionBase, 0, 0, &SectionOffset, &ConnectMessage->ClientView.ViewSize, ViewUnmap, 0, PAGE_READWRITE); /* Update the offset and check for mapping status */ ConnectMessage->ClientView.SectionOffset = SectionOffset.LowPart; if (NT_SUCCESS(Status)) { /* Set the view base */ ConnectMessage->ClientView.ViewRemoteBase = ServerPort-> ClientSectionBase; /* Save and reference the mapping process */ ServerPort->MappingProcess = PsGetCurrentProcess(); ObReferenceObject(ServerPort->MappingProcess); } else { /* Otherwise, quit */ ObDereferenceObject(ServerPort); DPRINT1("Client section mapping failed: %lx\n", Status); LPCTRACE(LPC_COMPLETE_DEBUG, "View base, offset, size: %p %lx %p\n", ServerPort->ClientSectionBase, ConnectMessage->ClientView.ViewSize, SectionOffset); goto Cleanup; } } /* Check if there's a server section */ if (ServerView) { /* FIXME: TODO */ UNREFERENCED_PARAMETER(CapturedServerView); ASSERT(FALSE); } /* Reference the server port until it's fully inserted */ ObReferenceObject(ServerPort); /* Insert the server port in the namespace */ Status = ObInsertObject(ServerPort, NULL, PORT_ALL_ACCESS, 0, NULL, &Handle); if (!NT_SUCCESS(Status)) { /* We failed, remove the extra reference and cleanup */ ObDereferenceObject(ServerPort); goto Cleanup; } /* Enter SEH to write back the results */ _SEH2_TRY { /* Check if the caller gave a client view */ if (ClientView) { /* Fill it out */ ClientView->ViewBase = ConnectMessage->ClientView.ViewRemoteBase; ClientView->ViewSize = ConnectMessage->ClientView.ViewSize; } /* Return the handle to user mode */ *PortHandle = Handle; } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* Cleanup and return the exception code */ ObCloseHandle(Handle, PreviousMode); ObDereferenceObject(ServerPort); Status = _SEH2_GetExceptionCode(); _SEH2_YIELD(goto Cleanup); } _SEH2_END; LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p. Messages: %p/%p. Ports: %p/%p/%p\n", Handle, Message, ConnectMessage, ServerPort, ClientPort, ConnectionPort); /* If there was no port context, use the handle by default */ if (!PortContext) ServerPort->PortContext = Handle; ServerPort->ClientThread = ClientThread; /* Set this message as the LPC Reply message while holding the lock */ KeAcquireGuardedMutex(&LpcpLock); ClientThread->LpcReplyMessage = Message; KeReleaseGuardedMutex(&LpcpLock); /* Clear the thread pointer so it doesn't get cleaned later */ ClientThread = NULL; /* Remove the extra reference we had added */ ObDereferenceObject(ServerPort); Cleanup: /* If there was a section, dereference it */ if (ClientSectionToMap) ObDereferenceObject(ClientSectionToMap); /* Check if we got here while still having a client thread */ if (ClientThread) { KeAcquireGuardedMutex(&LpcpLock); ClientThread->LpcReplyMessage = Message; LpcpPrepareToWakeClient(ClientThread); KeReleaseGuardedMutex(&LpcpLock); LpcpCompleteWait(&ClientThread->LpcReplySemaphore); ObDereferenceObject(ClientThread); } /* Dereference the client port if we have one, and the process */ LPCTRACE(LPC_COMPLETE_DEBUG, "Status: %lx. Thread: %p. Process: [%.16s]\n", Status, ClientThread, ClientProcess->ImageFileName); if (ClientPort) ObDereferenceObject(ClientPort); ObDereferenceObject(ClientProcess); return Status; } /* * @implemented */ NTSTATUS NTAPI NtCompleteConnectPort(IN HANDLE PortHandle) { NTSTATUS Status; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); PLPCP_PORT_OBJECT Port; PETHREAD Thread; PAGED_CODE(); LPCTRACE(LPC_COMPLETE_DEBUG, "Handle: %p\n", PortHandle); /* Get the Port Object */ Status = ObReferenceObjectByHandle(PortHandle, PORT_ALL_ACCESS, LpcPortObjectType, PreviousMode, (PVOID*)&Port, NULL); if (!NT_SUCCESS(Status)) return Status; /* Make sure this is a connection port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_COMMUNICATION_PORT) { /* It isn't, fail */ ObDereferenceObject(Port); return STATUS_INVALID_PORT_HANDLE; } /* Acquire the lock */ KeAcquireGuardedMutex(&LpcpLock); /* Make sure we have a client thread */ if (!Port->ClientThread) { /* We don't, fail */ KeReleaseGuardedMutex(&LpcpLock); ObDereferenceObject(Port); return STATUS_INVALID_PARAMETER; } /* Get the thread */ Thread = Port->ClientThread; /* Make sure it has a reply message */ if (!LpcpGetMessageFromThread(Thread)) { /* It doesn't, quit */ KeReleaseGuardedMutex(&LpcpLock); ObDereferenceObject(Port); return STATUS_SUCCESS; } /* Clear the client thread and wake it up */ Port->ClientThread = NULL; LpcpPrepareToWakeClient(Thread); /* Release the lock and wait for an answer */ KeReleaseGuardedMutex(&LpcpLock); LpcpCompleteWait(&Thread->LpcReplySemaphore); /* Dereference the Thread and Port and return */ ObDereferenceObject(Port); ObDereferenceObject(Thread); LPCTRACE(LPC_COMPLETE_DEBUG, "Port: %p. Thread: %p\n", Port, Thread); return Status; } /* EOF */