/* * PROJECT: ReactOS Kernel * LICENSE: GPL - See COPYING in the top level directory * FILE: ntoskrnl/lpc/connect.c * PURPOSE: Local Procedure Call: Connection Management * PROGRAMMERS: Alex Ionescu (alex.ionescu@reactos.org) */ /* INCLUDES ******************************************************************/ #include #define NDEBUG #include /* PRIVATE FUNCTIONS *********************************************************/ PVOID NTAPI LpcpFreeConMsg(IN OUT PLPCP_MESSAGE *Message, IN OUT PLPCP_CONNECTION_MESSAGE *ConnectMessage, IN PETHREAD CurrentThread) { PVOID SectionToMap; PLPCP_MESSAGE ReplyMessage; /* Acquire the LPC lock */ KeAcquireGuardedMutex(&LpcpLock); /* Check if the reply chain is not empty */ if (!IsListEmpty(&CurrentThread->LpcReplyChain)) { /* Remove this entry and re-initialize it */ RemoveEntryList(&CurrentThread->LpcReplyChain); InitializeListHead(&CurrentThread->LpcReplyChain); } /* Check if there's a reply message */ ReplyMessage = LpcpGetMessageFromThread(CurrentThread); if (ReplyMessage) { /* Get the message */ *Message = ReplyMessage; /* Check if it's got messages */ if (!IsListEmpty(&ReplyMessage->Entry)) { /* Clear the list */ RemoveEntryList(&ReplyMessage->Entry); InitializeListHead(&ReplyMessage->Entry); } /* Clear message data */ CurrentThread->LpcReceivedMessageId = 0; CurrentThread->LpcReplyMessage = NULL; /* Get the connection message and clear the section */ *ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(ReplyMessage + 1); SectionToMap = (*ConnectMessage)->SectionToMap; (*ConnectMessage)->SectionToMap = NULL; } else { /* No message to return */ *Message = NULL; SectionToMap = NULL; } /* Release the lock and return the section */ KeReleaseGuardedMutex(&LpcpLock); return SectionToMap; } /* PUBLIC FUNCTIONS **********************************************************/ /* * @implemented */ NTSTATUS NTAPI NtSecureConnectPort(OUT PHANDLE PortHandle, IN PUNICODE_STRING PortName, IN PSECURITY_QUALITY_OF_SERVICE SecurityQos, IN OUT PPORT_VIEW ClientView OPTIONAL, IN PSID ServerSid OPTIONAL, IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL, OUT PULONG MaxMessageLength OPTIONAL, IN OUT PVOID ConnectionInformation OPTIONAL, IN OUT PULONG ConnectionInformationLength OPTIONAL) { NTSTATUS Status = STATUS_SUCCESS; KPROCESSOR_MODE PreviousMode = KeGetPreviousMode(); PETHREAD Thread = PsGetCurrentThread(); #if DBG UNICODE_STRING CapturedPortName; #endif SECURITY_QUALITY_OF_SERVICE CapturedQos; PORT_VIEW CapturedClientView; PSID CapturedServerSid; ULONG ConnectionInfoLength = 0; PLPCP_PORT_OBJECT Port, ClientPort; PLPCP_MESSAGE Message; PLPCP_CONNECTION_MESSAGE ConnectMessage; ULONG PortMessageLength; HANDLE Handle; PVOID SectionToMap; LARGE_INTEGER SectionOffset; PTOKEN Token; PTOKEN_USER TokenUserInfo; PAGED_CODE(); /* Check if the call comes from user mode */ if (PreviousMode != KernelMode) { /* Enter SEH for probing the parameters */ _SEH2_TRY { /* Probe the PortHandle */ ProbeForWriteHandle(PortHandle); /* Probe and capture the QoS */ ProbeForRead(SecurityQos, sizeof(*SecurityQos), sizeof(ULONG)); CapturedQos = *(volatile SECURITY_QUALITY_OF_SERVICE*)SecurityQos; /* NOTE: Do not care about CapturedQos.Length */ /* The following parameters are optional */ /* Capture the client view */ if (ClientView) { ProbeForWrite(ClientView, sizeof(*ClientView), sizeof(ULONG)); CapturedClientView = *(volatile PORT_VIEW*)ClientView; /* Validate the size of the client view */ if (CapturedClientView.Length != sizeof(CapturedClientView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); } } /* Capture the server view */ if (ServerView) { ProbeForWrite(ServerView, sizeof(*ServerView), sizeof(ULONG)); /* Validate the size of the server view */ if (((volatile REMOTE_PORT_VIEW*)ServerView)->Length != sizeof(*ServerView)) { /* Invalid size */ _SEH2_YIELD(return STATUS_INVALID_PARAMETER); } } if (MaxMessageLength) ProbeForWriteUlong(MaxMessageLength); /* Capture connection information length */ if (ConnectionInformationLength) { ProbeForWriteUlong(ConnectionInformationLength); ConnectionInfoLength = *(volatile ULONG*)ConnectionInformationLength; } /* Probe the ConnectionInformation */ if (ConnectionInformation) ProbeForWrite(ConnectionInformation, ConnectionInfoLength, sizeof(ULONG)); CapturedServerSid = ServerSid; if (ServerSid) { /* Capture it */ Status = SepCaptureSid(ServerSid, PreviousMode, PagedPool, TRUE, &CapturedServerSid); if (!NT_SUCCESS(Status)) { DPRINT1("Failed to capture ServerSid!\n"); _SEH2_YIELD(return Status); } } } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* There was an exception, return the exception code */ _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } else { CapturedQos = *SecurityQos; /* NOTE: Do not care about CapturedQos.Length */ /* The following parameters are optional */ /* Capture the client view */ if (ClientView) { /* Validate the size of the client view */ if (ClientView->Length != sizeof(*ClientView)) { /* Invalid size */ return STATUS_INVALID_PARAMETER; } CapturedClientView = *ClientView; } /* Capture the server view */ if (ServerView) { /* Validate the size of the server view */ if (ServerView->Length != sizeof(*ServerView)) { /* Invalid size */ return STATUS_INVALID_PARAMETER; } } /* Capture connection information length */ if (ConnectionInformationLength) ConnectionInfoLength = *ConnectionInformationLength; CapturedServerSid = ServerSid; } #if DBG /* Capture the port name for DPRINT only - ObReferenceObjectByName does * its own capture. As it is used only for debugging, ignore any failure; * the string is zeroed out in such case. */ ProbeAndCaptureUnicodeString(&CapturedPortName, PreviousMode, PortName); LPCTRACE(LPC_CONNECT_DEBUG, "Name: %wZ. SecurityQos: %p. Views: %p/%p. Sid: %p\n", &CapturedPortName, SecurityQos, ClientView, ServerView, ServerSid); #endif /* Get the port */ Status = ObReferenceObjectByName(PortName, 0, NULL, PORT_CONNECT, LpcPortObjectType, PreviousMode, NULL, (PVOID*)&Port); if (!NT_SUCCESS(Status)) { #if DBG DPRINT1("Failed to reference port '%wZ': 0x%lx\n", &CapturedPortName, Status); ReleaseCapturedUnicodeString(&CapturedPortName, PreviousMode); #endif if (CapturedServerSid != ServerSid) SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); return Status; } /* This has to be a connection port */ if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT) { #if DBG DPRINT1("Port '%wZ' is not a connection port (Flags: 0x%lx)\n", &CapturedPortName, Port->Flags); ReleaseCapturedUnicodeString(&CapturedPortName, PreviousMode); #endif /* It isn't, so fail */ ObDereferenceObject(Port); if (CapturedServerSid != ServerSid) SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); return STATUS_INVALID_PORT_HANDLE; } /* Check if we have a (captured) SID */ if (ServerSid) { /* Make sure that we have a server */ if (Port->ServerProcess) { /* Get its token and query user information */ Token = PsReferencePrimaryToken(Port->ServerProcess); Status = SeQueryInformationToken(Token, TokenUser, (PVOID*)&TokenUserInfo); PsDereferencePrimaryToken(Token); /* Check for success */ if (NT_SUCCESS(Status)) { /* Compare the SIDs */ if (!RtlEqualSid(CapturedServerSid, TokenUserInfo->User.Sid)) { /* Fail */ #if DBG DPRINT1("Port '%wZ': server SID mismatch\n", &CapturedPortName); #endif Status = STATUS_SERVER_SID_MISMATCH; } /* Free token information */ ExFreePoolWithTag(TokenUserInfo, TAG_SE); } } else { /* Invalid SID */ #if DBG DPRINT1("Port '%wZ': server SID mismatch\n", &CapturedPortName); #endif Status = STATUS_SERVER_SID_MISMATCH; } /* Finally release the captured SID, we don't need it anymore */ if (CapturedServerSid != ServerSid) SepReleaseSid(CapturedServerSid, PreviousMode, TRUE); } #if DBG ReleaseCapturedUnicodeString(&CapturedPortName, PreviousMode); #endif /* Check if SID failed */ if (ServerSid && !NT_SUCCESS(Status)) { /* Quit */ ObDereferenceObject(Port); return Status; } /* Create the client port */ Status = ObCreateObject(PreviousMode, LpcPortObjectType, NULL, PreviousMode, NULL, sizeof(LPCP_PORT_OBJECT), 0, 0, (PVOID*)&ClientPort); if (!NT_SUCCESS(Status)) { /* Failed, dereference the server port and return */ DPRINT1("Failed to create Port object: 0x%lx\n", Status); ObDereferenceObject(Port); return Status; } /* * Setup the client port -- From now on, dereferencing the client port * will automatically dereference the connection port too. */ RtlZeroMemory(ClientPort, sizeof(LPCP_PORT_OBJECT)); ClientPort->Flags = LPCP_CLIENT_PORT; ClientPort->ConnectionPort = Port; ClientPort->MaxMessageLength = Port->MaxMessageLength; ClientPort->SecurityQos = CapturedQos; InitializeListHead(&ClientPort->LpcReplyChainHead); InitializeListHead(&ClientPort->LpcDataInfoChainHead); /* Check if we have dynamic security */ if (CapturedQos.ContextTrackingMode == SECURITY_DYNAMIC_TRACKING) { /* Remember that */ ClientPort->Flags |= LPCP_SECURITY_DYNAMIC; } else { /* Create our own client security */ Status = SeCreateClientSecurity(Thread, &CapturedQos, FALSE, &ClientPort->StaticSecurity); if (!NT_SUCCESS(Status)) { /* Security failed, dereference and return */ DPRINT1("SeCreateClientSecurity failed: 0x%lx\n", Status); ObDereferenceObject(ClientPort); return Status; } } /* Initialize the port queue */ Status = LpcpInitializePortQueue(ClientPort); if (!NT_SUCCESS(Status)) { /* Failed */ DPRINT1("LpcpInitializePortQueue failed: 0x%lx\n", Status); ObDereferenceObject(ClientPort); return Status; } /* Check if we have a client view */ if (ClientView) { /* Get the section handle */ Status = ObReferenceObjectByHandle(CapturedClientView.SectionHandle, SECTION_MAP_READ | SECTION_MAP_WRITE, MmSectionObjectType, PreviousMode, (PVOID*)&SectionToMap, NULL); if (!NT_SUCCESS(Status)) { /* Fail */ DPRINT1("Failed to reference port section handle: 0x%lx\n", Status); ObDereferenceObject(ClientPort); return Status; } /* Set the section offset */ SectionOffset.QuadPart = CapturedClientView.SectionOffset; /* Map it */ Status = MmMapViewOfSection(SectionToMap, PsGetCurrentProcess(), &ClientPort->ClientSectionBase, 0, 0, &SectionOffset, &CapturedClientView.ViewSize, ViewUnmap, 0, PAGE_READWRITE); /* Update the offset */ CapturedClientView.SectionOffset = SectionOffset.LowPart; /* Check for failure */ if (!NT_SUCCESS(Status)) { /* Fail */ DPRINT1("Failed to map port section: 0x%lx\n", Status); ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); return Status; } /* Update the base */ CapturedClientView.ViewBase = ClientPort->ClientSectionBase; /* Reference and remember the process */ ClientPort->MappingProcess = PsGetCurrentProcess(); ObReferenceObject(ClientPort->MappingProcess); } else { /* No section */ SectionToMap = NULL; } /* Normalize connection information */ if (ConnectionInfoLength > Port->MaxConnectionInfoLength) { /* Use the port's maximum allowed value */ ConnectionInfoLength = Port->MaxConnectionInfoLength; } /* Allocate a message from the port zone */ Message = LpcpAllocateFromPortZone(); if (!Message) { /* Fail if we couldn't allocate a message */ DPRINT1("LpcpAllocateFromPortZone failed\n"); if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); return STATUS_NO_MEMORY; } /* Set pointer to the connection message and fill in the CID */ ConnectMessage = (PLPCP_CONNECTION_MESSAGE)(Message + 1); Message->Request.ClientId = Thread->Cid; /* Check if we have a client view */ if (ClientView) { /* Set the view size */ Message->Request.ClientViewSize = CapturedClientView.ViewSize; /* Copy the client view and clear the server view */ RtlCopyMemory(&ConnectMessage->ClientView, &CapturedClientView, sizeof(CapturedClientView)); RtlZeroMemory(&ConnectMessage->ServerView, sizeof(REMOTE_PORT_VIEW)); } else { /* Set the size to 0 and clear the connect message */ Message->Request.ClientViewSize = 0; RtlZeroMemory(ConnectMessage, sizeof(LPCP_CONNECTION_MESSAGE)); } /* Set the section and client port. Port is NULL for now */ ConnectMessage->ClientPort = NULL; ConnectMessage->SectionToMap = SectionToMap; /* Set the data for the connection request message */ Message->Request.u1.s1.DataLength = (CSHORT)ConnectionInfoLength + sizeof(LPCP_CONNECTION_MESSAGE); Message->Request.u1.s1.TotalLength = sizeof(LPCP_MESSAGE) + Message->Request.u1.s1.DataLength; Message->Request.u2.s2.Type = LPC_CONNECTION_REQUEST; /* Check if we have connection information */ if (ConnectionInformation) { _SEH2_TRY { /* Copy it in */ RtlCopyMemory(ConnectMessage + 1, ConnectionInformation, ConnectionInfoLength); } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { DPRINT1("Exception 0x%lx when copying connection info to user mode\n", _SEH2_GetExceptionCode()); /* Cleanup and return the exception code */ /* Free the message we have */ LpcpFreeToPortZone(Message, 0); /* Dereference other objects */ if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); /* Return status */ _SEH2_YIELD(return _SEH2_GetExceptionCode()); } _SEH2_END; } /* Reset the status code */ Status = STATUS_SUCCESS; /* Acquire the port lock */ KeAcquireGuardedMutex(&LpcpLock); /* Check if someone already deleted the port name */ if (Port->Flags & LPCP_NAME_DELETED) { /* Fail the request */ Status = STATUS_OBJECT_NAME_NOT_FOUND; } else { /* Associate no thread yet */ Message->RepliedToThread = NULL; /* Generate the Message ID and set it */ Message->Request.MessageId = LpcpNextMessageId++; if (!LpcpNextMessageId) LpcpNextMessageId = 1; Thread->LpcReplyMessageId = Message->Request.MessageId; /* Insert the message into the queue and thread chain */ InsertTailList(&Port->MsgQueue.ReceiveHead, &Message->Entry); InsertTailList(&Port->LpcReplyChainHead, &Thread->LpcReplyChain); Thread->LpcReplyMessage = Message; /* Now we can finally reference the client port and link it */ ObReferenceObject(ClientPort); ConnectMessage->ClientPort = ClientPort; /* Enter a critical region */ KeEnterCriticalRegion(); } /* Add another reference to the port */ ObReferenceObject(Port); /* Release the lock */ KeReleaseGuardedMutex(&LpcpLock); /* Check for success */ if (NT_SUCCESS(Status)) { LPCTRACE(LPC_CONNECT_DEBUG, "Messages: %p/%p. Ports: %p/%p. Status: %lx\n", Message, ConnectMessage, Port, ClientPort, Status); /* If this is a waitable port, set the event */ if (Port->Flags & LPCP_WAITABLE_PORT) KeSetEvent(&Port->WaitEvent, 1, FALSE); /* Release the queue semaphore and leave the critical region */ LpcpCompleteWait(Port->MsgQueue.Semaphore); KeLeaveCriticalRegion(); /* Now wait for a reply and set 'Status' */ LpcpConnectWait(&Thread->LpcReplySemaphore, PreviousMode); } /* Now, always free the connection message */ SectionToMap = LpcpFreeConMsg(&Message, &ConnectMessage, Thread); /* Check for failure */ if (!NT_SUCCESS(Status)) { /* Check if the semaphore got signaled in the meantime */ if (KeReadStateSemaphore(&Thread->LpcReplySemaphore)) { /* Wait on it */ KeWaitForSingleObject(&Thread->LpcReplySemaphore, WrExecutive, KernelMode, FALSE, NULL); } goto Failure; } /* Check if we got a message back */ if (Message) { /* Check for new return length */ if ((Message->Request.u1.s1.DataLength - sizeof(LPCP_CONNECTION_MESSAGE)) < ConnectionInfoLength) { /* Set new normalized connection length */ ConnectionInfoLength = Message->Request.u1.s1.DataLength - sizeof(LPCP_CONNECTION_MESSAGE); } /* Check if the caller had connection information */ if (ConnectionInformation) { _SEH2_TRY { /* Return the connection information length if needed */ if (ConnectionInformationLength) *ConnectionInformationLength = ConnectionInfoLength; /* Return the connection information */ RtlCopyMemory(ConnectionInformation, ConnectMessage + 1, ConnectionInfoLength); } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* Cleanup and return the exception code */ Status = _SEH2_GetExceptionCode(); _SEH2_YIELD(goto Failure); } _SEH2_END; } /* Make sure we had a connected port */ if (ClientPort->ConnectedPort) { /* Get the message length before the port might get killed */ PortMessageLength = Port->MaxMessageLength; /* Insert the client port */ Status = ObInsertObject(ClientPort, NULL, PORT_ALL_ACCESS, 0, NULL, &Handle); if (NT_SUCCESS(Status)) { LPCTRACE(LPC_CONNECT_DEBUG, "Handle: %p. Length: %lx\n", Handle, PortMessageLength); _SEH2_TRY { /* Return the handle */ *PortHandle = Handle; /* Check if maximum length was requested */ if (MaxMessageLength) *MaxMessageLength = PortMessageLength; /* Check if we had a client view */ if (ClientView) { /* Copy it back */ RtlCopyMemory(ClientView, &ConnectMessage->ClientView, sizeof(*ClientView)); } /* Check if we had a server view */ if (ServerView) { /* Copy it back */ RtlCopyMemory(ServerView, &ConnectMessage->ServerView, sizeof(*ServerView)); } } _SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { /* An exception happened, close the opened handle */ ObCloseHandle(Handle, PreviousMode); Status = _SEH2_GetExceptionCode(); } _SEH2_END; } } else { /* No connection port, we failed */ if (SectionToMap) ObDereferenceObject(SectionToMap); /* Acquire the lock */ KeAcquireGuardedMutex(&LpcpLock); /* Check if it's because the name got deleted */ if (!(ClientPort->ConnectionPort) || (Port->Flags & LPCP_NAME_DELETED)) { /* Set the correct status */ Status = STATUS_OBJECT_NAME_NOT_FOUND; } else { /* Otherwise, the caller refused us */ Status = STATUS_PORT_CONNECTION_REFUSED; } /* Release the lock */ KeReleaseGuardedMutex(&LpcpLock); /* Kill the port */ ObDereferenceObject(ClientPort); } /* Free the message */ LpcpFreeToPortZone(Message, 0); } else { /* No reply message, fail */ Status = STATUS_PORT_CONNECTION_REFUSED; goto Failure; } ObDereferenceObject(Port); /* Return status */ return Status; Failure: /* Check if we had a message and free it */ if (Message) LpcpFreeToPortZone(Message, 0); /* Dereference other objects */ if (SectionToMap) ObDereferenceObject(SectionToMap); ObDereferenceObject(ClientPort); ObDereferenceObject(Port); /* Return status */ return Status; } /* * @implemented */ NTSTATUS NTAPI NtConnectPort(OUT PHANDLE PortHandle, IN PUNICODE_STRING PortName, IN PSECURITY_QUALITY_OF_SERVICE SecurityQos, IN OUT PPORT_VIEW ClientView OPTIONAL, IN OUT PREMOTE_PORT_VIEW ServerView OPTIONAL, OUT PULONG MaxMessageLength OPTIONAL, IN OUT PVOID ConnectionInformation OPTIONAL, IN OUT PULONG ConnectionInformationLength OPTIONAL) { /* Call the newer API */ return NtSecureConnectPort(PortHandle, PortName, SecurityQos, ClientView, NULL, ServerView, MaxMessageLength, ConnectionInformation, ConnectionInformationLength); } /* EOF */