[CSRSRV]: Improve ClientConnectionThread a bit to make it look a bit more like CSRSRV2 and add some extra functionality.

[CSRSRV]: Port from CSRSRV2 and use CsrApiPortInitialize instead of CsrpCreateListenPort. This will set appropriate SDs and also wait for all threads to be ready.

svn path=/trunk/; revision=55638
This commit is contained in:
Alex Ionescu 2012-02-16 16:40:15 +00:00
parent a6998d5930
commit d2af91632d
3 changed files with 221 additions and 82 deletions

View file

@ -19,6 +19,7 @@
static unsigned ApiDefinitionsCount = 0; static unsigned ApiDefinitionsCount = 0;
static PCSRSS_API_DEFINITION ApiDefinitions = NULL; static PCSRSS_API_DEFINITION ApiDefinitions = NULL;
UNICODE_STRING CsrApiPortName;
/* FUNCTIONS *****************************************************************/ /* FUNCTIONS *****************************************************************/
@ -303,6 +304,129 @@ CsrSrvAttachSharedSection(IN PCSR_PROCESS CsrProcess OPTIONAL,
return STATUS_SUCCESS; return STATUS_SUCCESS;
} }
/*++
* @name CsrApiPortInitialize
*
* The CsrApiPortInitialize routine initializes the LPC Port used for
* communications with the Client/Server Runtime (CSR) and initializes the
* static thread that will handle connection requests and APIs.
*
* @param None
*
* @return STATUS_SUCCESS in case of success, STATUS_UNSUCCESSFUL
* othwerwise.
*
* @remarks None.
*
*--*/
NTSTATUS
NTAPI
CsrApiPortInitialize(VOID)
{
ULONG Size;
OBJECT_ATTRIBUTES ObjectAttributes;
NTSTATUS Status;
HANDLE hRequestEvent, hThread;
CLIENT_ID ClientId;
PLIST_ENTRY ListHead, NextEntry;
PCSR_THREAD ServerThread;
/* Calculate how much space we'll need for the Port Name */
Size = CsrDirectoryName.Length + sizeof(CSR_PORT_NAME) + sizeof(WCHAR);
/* Create the buffer for it */
CsrApiPortName.Buffer = RtlAllocateHeap(CsrHeap, 0, Size);
if (!CsrApiPortName.Buffer) return STATUS_NO_MEMORY;
/* Setup the rest of the empty string */
CsrApiPortName.Length = 0;
CsrApiPortName.MaximumLength = (USHORT)Size;
RtlAppendUnicodeStringToString(&CsrApiPortName, &CsrDirectoryName);
RtlAppendUnicodeToString(&CsrApiPortName, UNICODE_PATH_SEP);
RtlAppendUnicodeToString(&CsrApiPortName, CSR_PORT_NAME);
if (CsrDebug & 1)
{
DPRINT1("CSRSS: Creating %wZ port and associated threads\n", &CsrApiPortName);
DPRINT1("CSRSS: sizeof( CONNECTINFO ) == %ld sizeof( API_MSG ) == %ld\n",
sizeof(CSR_CONNECTION_INFO), sizeof(CSR_API_MESSAGE));
}
/* FIXME: Create a Security Descriptor */
/* Initialize the Attributes */
InitializeObjectAttributes(&ObjectAttributes,
&CsrApiPortName,
0,
NULL,
NULL /* FIXME*/);
/* Create the Port Object */
Status = NtCreatePort(&hApiPort, //&CsrApiPort,
&ObjectAttributes,
sizeof(CSR_CONNECTION_INFO),
sizeof(CSR_API_MESSAGE),
16 * PAGE_SIZE);
if (NT_SUCCESS(Status))
{
/* Create the event the Port Thread will use */
Status = NtCreateEvent(&hRequestEvent,
EVENT_ALL_ACCESS,
NULL,
SynchronizationEvent,
FALSE);
if (NT_SUCCESS(Status))
{
/* Create the Request Thread */
Status = RtlCreateUserThread(NtCurrentProcess(),
NULL,
TRUE,
0,
0,
0,
(PVOID)ClientConnectionThread,//CsrApiRequestThread,
(PVOID)hRequestEvent,
&hThread,
&ClientId);
if (NT_SUCCESS(Status))
{
/* Add this as a static thread to CSRSRV */
CsrAddStaticServerThread(hThread, &ClientId, CsrThreadIsServerThread);
/* Get the Thread List Pointers */
ListHead = &CsrRootProcess->ThreadList;
NextEntry = ListHead->Flink;
/* Start looping the list */
while (NextEntry != ListHead)
{
/* Get the Thread */
ServerThread = CONTAINING_RECORD(NextEntry, CSR_THREAD, Link);
/* Start it up */
Status = NtResumeThread(ServerThread->ThreadHandle, NULL);
/* Is this a Server Thread? */
if (ServerThread->Flags & CsrThreadIsServerThread)
{
/* If so, then wait for it to initialize */
Status = NtWaitForSingleObject(hRequestEvent, FALSE, NULL);
ASSERT(NT_SUCCESS(Status));
}
/* Next thread */
NextEntry = NextEntry->Flink;
}
/* We don't need this anymore */
NtClose(hRequestEvent);
}
}
}
/* Return */
return Status;
}
PBASE_STATIC_SERVER_DATA BaseStaticServerData; PBASE_STATIC_SERVER_DATA BaseStaticServerData;
NTSTATUS NTSTATUS
@ -685,8 +809,7 @@ BasepFakeStaticServerData(VOID)
} }
NTSTATUS WINAPI NTSTATUS WINAPI
CsrpHandleConnectionRequest (PPORT_MESSAGE Request, CsrpHandleConnectionRequest (PPORT_MESSAGE Request)
IN HANDLE hApiListenPort)
{ {
NTSTATUS Status; NTSTATUS Status;
HANDLE ServerPort = NULL, ServerThread = NULL; HANDLE ServerPort = NULL, ServerThread = NULL;
@ -780,7 +903,7 @@ CsrpHandleConnectionRequest (PPORT_MESSAGE Request,
0, 0,
0, 0,
(PTHREAD_START_ROUTINE)ClientConnectionThread, (PTHREAD_START_ROUTINE)ClientConnectionThread,
ServerPort, NULL,
& ServerThread, & ServerThread,
&ClientId); &ClientId);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
@ -849,74 +972,130 @@ CsrConnectToUser(VOID)
VOID VOID
WINAPI WINAPI
ClientConnectionThread(HANDLE ServerPort) ClientConnectionThread(IN PVOID Parameter)
{ {
PTEB Teb = NtCurrentTeb();
LARGE_INTEGER TimeOut;
NTSTATUS Status; NTSTATUS Status;
BYTE RawRequest[LPC_MAX_DATA_LENGTH]; BYTE RawRequest[LPC_MAX_DATA_LENGTH];
PCSR_API_MESSAGE Request = (PCSR_API_MESSAGE)RawRequest; PCSR_API_MESSAGE Request = (PCSR_API_MESSAGE)RawRequest;
PCSR_API_MESSAGE Reply; PCSR_API_MESSAGE Reply;
PCSR_PROCESS ProcessData; PCSR_PROCESS ProcessData;
PCSR_THREAD ServerThread; PCSR_THREAD ServerThread;
ULONG MessageType;
DPRINT("CSR: %s called\n", __FUNCTION__); DPRINT("CSR: %s called\n", __FUNCTION__);
/* Setup LPC loop port and message */
Reply = NULL;
// ReplyPort = CsrApiPort;
/* Connect to user32 */ /* Connect to user32 */
while (!CsrConnectToUser()) while (!CsrConnectToUser())
{ {
/* Set up the timeout for the connect (30 seconds) */
TimeOut.QuadPart = -30 * 1000 * 1000 * 10;
/* Keep trying until we get a response */ /* Keep trying until we get a response */
NtCurrentTeb()->Win32ClientInfo[0] = 0; Teb->Win32ClientInfo[0] = 0;
//NtDelayExecution(FALSE, &TimeOut); NtDelayExecution(FALSE, &TimeOut);
} }
/* Reply must be NULL at the first call to NtReplyWaitReceivePort */ /* Get our thread */
ServerThread = NtCurrentTeb()->CsrClientThread; ServerThread = Teb->CsrClientThread;
Reply = NULL;
/* Loop and reply/wait for a new message */ /* If we got an event... */
for (;;) if (Parameter)
{ {
/* Set it, to let stuff waiting on us load */
Status = NtSetEvent((HANDLE)Parameter, NULL);
ASSERT(NT_SUCCESS(Status));
/* Increase the Thread Counts */
//_InterlockedIncrement(&CsrpStaticThreadCount);
//_InterlockedIncrement(&CsrpDynamicThreadTotal);
}
/* Now start the loop */
while (TRUE)
{
/* Make sure the real CID is set */
Teb->RealClientId = Teb->ClientId;
/* Debug check */
if (Teb->CountOfOwnedCriticalSections)
{
DPRINT1("CSRSRV: FATAL ERROR. CsrThread is Idle while holding %lu critical sections\n",
Teb->CountOfOwnedCriticalSections);
DPRINT1("CSRSRV: Last Receive Message %lx ReplyMessage %lx\n",
&Request, Reply);
DbgBreakPoint();
}
/* Send the reply and wait for a new request */ /* Send the reply and wait for a new request */
Status = NtReplyWaitReceivePort(hApiPort, Status = NtReplyWaitReceivePort(hApiPort,
0, 0,
&Reply->Header, &Reply->Header,
&Request->Header); &Request->Header);
/* Client died, continue */ /* Check if we didn't get success */
if (Status == STATUS_INVALID_CID) if (Status != STATUS_SUCCESS)
{ {
Reply = NULL; /* Was it a failure or another success code? */
continue; if (!NT_SUCCESS(Status))
{
/* Check for specific status cases */
if ((Status != STATUS_INVALID_CID) &&
(Status != STATUS_UNSUCCESSFUL))// &&
// ((Status == STATUS_INVALID_HANDLE) || (ReplyPort == CsrApiPort)))
{
/* Notify the debugger */
DPRINT1("CSRSS: ReceivePort failed - Status == %X\n", Status);
//DPRINT1("CSRSS: ReplyPortHandle %lx CsrApiPort %lx\n", ReplyPort, CsrApiPort);
}
/* We failed big time, so start out fresh */
Reply = NULL;
//ReplyPort = CsrApiPort;
continue;
}
else
{
/* A bizare "success" code, just try again */
DPRINT1("NtReplyWaitReceivePort returned \"success\" status 0x%x\n", Status);
continue;
}
} }
if (!NT_SUCCESS(Status)) /* Use whatever Client ID we got */
{ Teb->RealClientId = Request->Header.ClientId;
DPRINT1("NtReplyWaitReceivePort failed: %lx\n", Status);
break; /* Get the Message Type */
} MessageType = Request->Header.u2.s2.Type;
/* If the connection was closed, handle that */ /* If the connection was closed, handle that */
if (Request->Header.u2.s2.Type == LPC_PORT_CLOSED) if (MessageType == LPC_PORT_CLOSED)
{ {
DPRINT("Port died, oh well\n"); DPRINT("Port died, oh well\n");
CsrFreeProcessData( Request->Header.ClientId.UniqueProcess ); CsrFreeProcessData( Request->Header.ClientId.UniqueProcess );
break; break;
} }
if (Request->Header.u2.s2.Type == LPC_CONNECTION_REQUEST) if (MessageType == LPC_CONNECTION_REQUEST)
{ {
CsrpHandleConnectionRequest((PPORT_MESSAGE)Request, ServerPort); CsrpHandleConnectionRequest((PPORT_MESSAGE)Request);
Reply = NULL; Reply = NULL;
continue; continue;
} }
if (Request->Header.u2.s2.Type == LPC_CLIENT_DIED) if (MessageType == LPC_CLIENT_DIED)
{ {
DPRINT("Client died, oh well\n"); DPRINT("Client died, oh well\n");
Reply = NULL; Reply = NULL;
continue; continue;
} }
if ((Request->Header.u2.s2.Type != LPC_ERROR_EVENT) && if ((MessageType != LPC_ERROR_EVENT) &&
(Request->Header.u2.s2.Type != LPC_REQUEST)) (MessageType != LPC_REQUEST))
{ {
DPRINT1("CSR: received message %d\n", Request->Header.u2.s2.Type); DPRINT1("CSR: received message %d\n", Request->Header.u2.s2.Type);
Reply = NULL; Reply = NULL;
@ -932,7 +1111,7 @@ ClientConnectionThread(HANDLE ServerPort)
if (ProcessData == NULL) if (ProcessData == NULL)
{ {
DPRINT1("Message %d: Unable to find data for process 0x%x\n", DPRINT1("Message %d: Unable to find data for process 0x%x\n",
Request->Header.u2.s2.Type, MessageType,
Request->Header.ClientId.UniqueProcess); Request->Header.ClientId.UniqueProcess);
break; break;
} }
@ -944,7 +1123,7 @@ ClientConnectionThread(HANDLE ServerPort)
} }
/* Check if we got a hard error */ /* Check if we got a hard error */
if (Request->Header.u2.s2.Type == LPC_ERROR_EVENT) if (MessageType == LPC_ERROR_EVENT)
{ {
/* Call the Handler */ /* Call the Handler */
CsrHandleHardError(ProcessData, (PHARDERROR_MSG)Request); CsrHandleHardError(ProcessData, (PHARDERROR_MSG)Request);
@ -972,7 +1151,9 @@ ClientConnectionThread(HANDLE ServerPort)
// NtClose(ServerPort); // NtClose(ServerPort);
DPRINT("CSR: %s done\n", __FUNCTION__); DPRINT("CSR: %s done\n", __FUNCTION__);
RtlExitUserThread(STATUS_SUCCESS); /* We're out of the loop for some reason, terminate! */
NtTerminateThread(NtCurrentThread(), Status);
//return Status;
} }
/*++ /*++

View file

@ -176,57 +176,6 @@ CSRSS_API_DEFINITION NativeDefinitions[] =
{ 0, 0, NULL } { 0, 0, NULL }
}; };
static NTSTATUS WINAPI
CsrpCreateListenPort (IN LPWSTR Name,
IN OUT PHANDLE Port,
IN PTHREAD_START_ROUTINE ListenThread)
{
NTSTATUS Status = STATUS_SUCCESS;
OBJECT_ATTRIBUTES PortAttributes;
UNICODE_STRING PortName;
HANDLE ServerThread;
CLIENT_ID ClientId;
DPRINT("CSR: %s called\n", __FUNCTION__);
RtlInitUnicodeString (& PortName, Name);
InitializeObjectAttributes (& PortAttributes,
& PortName,
0,
NULL,
NULL);
Status = NtCreatePort ( Port,
& PortAttributes,
sizeof(SB_CONNECTION_INFO),
sizeof(SB_API_MSG),
32 * sizeof(SB_API_MSG));
if(!NT_SUCCESS(Status))
{
DPRINT1("CSR: %s: NtCreatePort failed (Status=%08lx)\n",
__FUNCTION__, Status);
return Status;
}
Status = RtlCreateUserThread(NtCurrentProcess(),
NULL,
TRUE,
0,
0,
0,
(PTHREAD_START_ROUTINE) ListenThread,
*Port,
&ServerThread,
&ClientId);
if (ListenThread == (PVOID)ClientConnectionThread)
{
CsrAddStaticServerThread(ServerThread, &ClientId, 0);
}
NtResumeThread(ServerThread, NULL);
NtClose(ServerThread);
return Status;
}
/* === INIT ROUTINES === */ /* === INIT ROUTINES === */
VOID VOID
@ -1150,10 +1099,13 @@ CsrServerInitialization(IN ULONG ArgumentCount,
DPRINT1("CSRSRV failed in %s with status %lx\n", "CsrApiRegisterDefinitions", Status); DPRINT1("CSRSRV failed in %s with status %lx\n", "CsrApiRegisterDefinitions", Status);
} }
Status = CsrpCreateListenPort(L"\\Windows\\ApiPort", &hApiPort, (PTHREAD_START_ROUTINE)ClientConnectionThread); /* Now initialize our API Port */
Status = CsrApiPortInitialize();
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
DPRINT1("CSRSRV failed in %s with status %lx\n", "CsrpCreateApiPort", Status); DPRINT1("CSRSRV:%s: CsrApiPortInitialize failed (Status=%08lx)\n",
__FUNCTION__, Status);
return Status;
} }
Status = CsrpInitWin32Csr(); Status = CsrpInitWin32Csr();

View file

@ -235,6 +235,12 @@ extern HANDLE CsrSbApiPort;
extern LIST_ENTRY CsrThreadHashTable[256]; extern LIST_ENTRY CsrThreadHashTable[256];
extern PCSR_PROCESS CsrRootProcess; extern PCSR_PROCESS CsrRootProcess;
extern RTL_CRITICAL_SECTION ProcessDataLock, CsrWaitListsLock; extern RTL_CRITICAL_SECTION ProcessDataLock, CsrWaitListsLock;
extern UNICODE_STRING CsrDirectoryName;
extern ULONG CsrDebug;
NTSTATUS
NTAPI
CsrApiPortInitialize(VOID);
BOOLEAN BOOLEAN
NTAPI NTAPI