[ntoskrnl/lpc]

- Implement NtReplyPort based on NtReplyWaitReceivePortEx and LpcReplyPort.
- Implement NtRequestPort based on NtRequestWaitReplyPort and LpcRequestPort.

svn path=/trunk/; revision=43603
This commit is contained in:
Aleksey Bragin 2009-10-19 15:49:29 +00:00
parent b8f74a1482
commit 884aa8948f
2 changed files with 353 additions and 6 deletions

View file

@ -141,10 +141,172 @@ LpcpMoveMessage(IN PPORT_MESSAGE Destination,
NTSTATUS
NTAPI
NtReplyPort(IN HANDLE PortHandle,
IN PPORT_MESSAGE LpcReply)
IN PPORT_MESSAGE ReplyMessage)
{
UNIMPLEMENTED;
return STATUS_NOT_IMPLEMENTED;
PLPCP_PORT_OBJECT Port, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status;
PLPCP_MESSAGE Message;
PETHREAD Thread = PsGetCurrentThread(), WakeupThread;
//PORT_MESSAGE CapturedReplyMessage;
PAGED_CODE();
LPCTRACE(LPC_REPLY_DEBUG,
"Handle: %lx. Message: %p.\n",
PortHandle,
ReplyMessage);
if (KeGetPreviousMode() == UserMode)
{
_SEH2_TRY
{
if (ReplyMessage != NULL)
{
ProbeForRead(ReplyMessage, sizeof(PORT_MESSAGE), sizeof(ULONG));
/*RtlCopyMemory(&CapturedReplyMessage, ReplyMessage, sizeof(PORT_MESSAGE));
ReplyMessage = &CapturedReplyMessage;*/
}
}
_SEH2_EXCEPT(ExSystemExceptionFilter())
{
DPRINT1("SEH crash [1]\n");
DbgBreakPoint();
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
}
/* Validate its length */
if (((ULONG)ReplyMessage->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)ReplyMessage->u1.s1.TotalLength)
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Make sure it has a valid ID */
if (!ReplyMessage->MessageId) return STATUS_INVALID_PARAMETER;
/* Get the Port object */
Status = ObReferenceObjectByHandle(PortHandle,
0,
LpcPortObjectType,
PreviousMode,
(PVOID*)&Port,
NULL);
if (!NT_SUCCESS(Status)) return Status;
/* Validate its length in respect to the port object */
if (((ULONG)ReplyMessage->u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)ReplyMessage->u1.s1.TotalLength <=
(ULONG)ReplyMessage->u1.s1.DataLength))
{
/* Too large, fail */
ObDereferenceObject(Port);
return STATUS_PORT_MESSAGE_TOO_LONG;
}
/* Get the ETHREAD corresponding to it */
Status = PsLookupProcessThreadByCid(&ReplyMessage->ClientId,
NULL,
&WakeupThread);
if (!NT_SUCCESS(Status))
{
/* No thread found, fail */
ObDereferenceObject(Port);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return Status;
}
/* Allocate a message from the port zone */
Message = LpcpAllocateFromPortZone();
if (!Message)
{
/* Fail if we couldn't allocate a message */
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port);
return STATUS_NO_MEMORY;
}
/* Keep the lock acquired */
KeAcquireGuardedMutex(&LpcpLock);
/* Make sure this is the reply the thread is waiting for */
if ((WakeupThread->LpcReplyMessageId != ReplyMessage->MessageId) ||
((LpcpGetMessageFromThread(WakeupThread)) &&
(LpcpGetMessageType(&LpcpGetMessageFromThread(WakeupThread)->
Request) != LPC_REQUEST)))
{
/* It isn't, fail */
LpcpFreeToPortZone(Message, 3);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port);
return STATUS_REPLY_MESSAGE_MISMATCH;
}
/* Copy the message */
_SEH2_TRY
{
LpcpMoveMessage(&Message->Request,
ReplyMessage,
ReplyMessage + 1,
LPC_REPLY,
NULL);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(WakeupThread);
ObDereferenceObject(Port);
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
/* Reference the thread while we use it */
ObReferenceObject(WakeupThread);
Message->RepliedToThread = WakeupThread;
/* Set this as the reply message */
WakeupThread->LpcReplyMessageId = 0;
WakeupThread->LpcReplyMessage = (PVOID)Message;
/* Check if we have messages on the reply chain */
if (!(WakeupThread->LpcExitThreadCalled) &&
!(IsListEmpty(&WakeupThread->LpcReplyChain)))
{
/* Remove us from it and reinitialize it */
RemoveEntryList(&WakeupThread->LpcReplyChain);
InitializeListHead(&WakeupThread->LpcReplyChain);
}
/* Check if this is the message the thread had received */
if ((Thread->LpcReceivedMsgIdValid) &&
(Thread->LpcReceivedMessageId == ReplyMessage->MessageId))
{
/* Clear this data */
Thread->LpcReceivedMessageId = 0;
Thread->LpcReceivedMsgIdValid = FALSE;
}
/* Free any data information */
LpcpFreeDataInfoMessage(Port,
ReplyMessage->MessageId,
ReplyMessage->CallbackId,
ReplyMessage->ClientId);
/* Release the lock and release the LPC semaphore to wake up waiters */
KeReleaseGuardedMutex(&LpcpLock);
LpcpCompleteWait(&WakeupThread->LpcReplySemaphore);
/* Now we can let go of the thread */
ObDereferenceObject(WakeupThread);
/* Dereference port object */
ObDereferenceObject(Port);
return Status;
}
/*

View file

@ -442,10 +442,195 @@ LpcRequestWaitReplyPort(IN PVOID PortObject,
NTSTATUS
NTAPI
NtRequestPort(IN HANDLE PortHandle,
IN PPORT_MESSAGE LpcMessage)
IN PPORT_MESSAGE LpcRequest)
{
UNIMPLEMENTED;
return STATUS_NOT_IMPLEMENTED;
PLPCP_PORT_OBJECT Port, QueuePort, ConnectionPort = NULL;
KPROCESSOR_MODE PreviousMode = KeGetPreviousMode();
NTSTATUS Status;
PLPCP_MESSAGE Message;
PETHREAD Thread = PsGetCurrentThread();
PKSEMAPHORE Semaphore;
ULONG MessageType;
PAGED_CODE();
LPCTRACE(LPC_SEND_DEBUG,
"Handle: %lx. Message: %p. Type: %lx\n",
PortHandle,
LpcRequest,
LpcpGetMessageType(LpcRequest));
/* Get the message type */
MessageType = LpcRequest->u2.s2.Type | LPC_DATAGRAM;
/* Can't have data information on this type of call */
if (LpcRequest->u2.s2.DataInfoOffset) return STATUS_INVALID_PARAMETER;
/* Validate the length */
if (((ULONG)LpcRequest->u1.s1.DataLength + sizeof(PORT_MESSAGE)) >
(ULONG)LpcRequest->u1.s1.TotalLength)
{
/* Fail */
return STATUS_INVALID_PARAMETER;
}
/* Reference the object */
Status = ObReferenceObjectByHandle(PortHandle,
0,
LpcPortObjectType,
PreviousMode,
(PVOID*)&Port,
NULL);
if (!NT_SUCCESS(Status)) return Status;
/* Validate the message length */
if (((ULONG)LpcRequest->u1.s1.TotalLength > Port->MaxMessageLength) ||
((ULONG)LpcRequest->u1.s1.TotalLength <= (ULONG)LpcRequest->u1.s1.DataLength))
{
/* Fail */
ObDereferenceObject(Port);
return STATUS_PORT_MESSAGE_TOO_LONG;
}
/* Allocate a message from the port zone */
Message = LpcpAllocateFromPortZone();
if (!Message)
{
/* Fail if we couldn't allocate a message */
ObDereferenceObject(Port);
return STATUS_NO_MEMORY;
}
/* No callback, just copy the message */
_SEH2_TRY
{
/* Copy it */
LpcpMoveMessage(&Message->Request,
LpcRequest,
LpcRequest + 1,
MessageType,
&Thread->Cid);
}
_SEH2_EXCEPT(EXCEPTION_EXECUTE_HANDLER)
{
/* Fail */
LpcpFreeToPortZone(Message, 0);
ObDereferenceObject(Port);
_SEH2_YIELD(return _SEH2_GetExceptionCode());
}
_SEH2_END;
/* Acquire the LPC lock */
KeAcquireGuardedMutex(&LpcpLock);
/* Right now clear the port context */
Message->PortContext = NULL;
/* Check if this is a not connection port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) != LPCP_CONNECTION_PORT)
{
/* We want the connected port */
QueuePort = Port->ConnectedPort;
if (!QueuePort)
{
/* We have no connected port, fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
/* Check if this is a communication port */
if ((Port->Flags & LPCP_PORT_TYPE_MASK) == LPCP_CLIENT_PORT)
{
/* Copy the port context and use the connection port */
Message->PortContext = QueuePort->PortContext;
ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
}
else if ((Port->Flags & LPCP_PORT_TYPE_MASK) !=
LPCP_COMMUNICATION_PORT)
{
/* Use the connection port for anything but communication ports */
ConnectionPort = QueuePort = Port->ConnectionPort;
if (!ConnectionPort)
{
/* Fail */
LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
return STATUS_PORT_DISCONNECTED;
}
}
/* Reference the connection port if it exists */
if (ConnectionPort) ObReferenceObject(ConnectionPort);
}
else
{
/* Otherwise, for a connection port, use the same port object */
QueuePort = Port;
}
/* Reference QueuePort if we have it */
if (QueuePort && ObReferenceObjectSafe(QueuePort))
{
/* Set sender's port */
Message->SenderPort = Port;
/* Generate the Message ID and set it */
Message->Request.MessageId = LpcpNextMessageId++;
if (!LpcpNextMessageId) LpcpNextMessageId = 1;
Message->Request.CallbackId = 0;
/* No Message ID for the thread */
PsGetCurrentThread()->LpcReplyMessageId = 0;
/* Insert the message in our chain */
InsertTailList(&QueuePort->MsgQueue.ReceiveHead, &Message->Entry);
/* Release the lock and get the semaphore we'll use later */
KeEnterCriticalRegion();
KeReleaseGuardedMutex(&LpcpLock);
/* Now release the semaphore */
Semaphore = QueuePort->MsgQueue.Semaphore;
LpcpCompleteWait(Semaphore);
/* If this is a waitable port, wake it up */
if (QueuePort->Flags & LPCP_WAITABLE_PORT)
{
/* Wake it */
KeSetEvent(&QueuePort->WaitEvent, IO_NO_INCREMENT, FALSE);
}
KeLeaveCriticalRegion();
/* Dereference objects */
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
ObDereferenceObject(QueuePort);
ObDereferenceObject(Port);
LPCTRACE(LPC_SEND_DEBUG, "Port: %p. Message: %p\n", QueuePort, Message);
return STATUS_SUCCESS;
}
Status = STATUS_PORT_DISCONNECTED;
/* All done with a failure*/
LPCTRACE(LPC_SEND_DEBUG,
"Port: %p. Status: %p\n",
Port,
Status);
/* The wait failed, free the message */
if (Message) LpcpFreeToPortZone(Message, 3);
ObDereferenceObject(Port);
if (ConnectionPort) ObDereferenceObject(ConnectionPort);
return Status;
}
/*