[SERVICES] Merge ScmControlService() and ScmSendStartCommand() together (#7392)

In addition:

- Acquire ControlServiceCriticalSection just before doing the pipe
  operations, and release it just after.

- SAL2-annotate ScmControlService().

- Re-order the ScmControlService() parameters in a more natural way
  (image comm pipe, service name, control code; then: arguments for
  the control command).

- Improve some DPRINTs.
This commit is contained in:
Hermès Bélusca-Maïto 2018-02-25 17:57:54 +01:00
parent 38f21f93e9
commit 0f7b021fe6
No known key found for this signature in database
GPG key ID: 3B2539C65E7B93D0
3 changed files with 101 additions and 257 deletions

View file

@ -1414,206 +1414,37 @@ ScmGetBootAndSystemDriverState(VOID)
* The service passed must always be referenced instead. * The service passed must always be referenced instead.
*/ */
DWORD DWORD
ScmControlService(HANDLE hControlPipe, ScmControlServiceEx(
PWSTR pServiceName, _In_ HANDLE hControlPipe,
SERVICE_STATUS_HANDLE hServiceStatus, _In_ PCWSTR pServiceName,
DWORD dwControl) _In_ DWORD dwControl,
_In_ SERVICE_STATUS_HANDLE hServiceStatus,
_In_opt_ DWORD dwServiceTag,
_In_opt_ DWORD argc,
_In_reads_opt_(argc) PWSTR* argv)
{ {
PSCM_CONTROL_PACKET ControlPacket;
SCM_REPLY_PACKET ReplyPacket;
DWORD dwWriteCount = 0;
DWORD dwReadCount = 0;
DWORD PacketSize;
PWSTR Ptr;
DWORD dwError = ERROR_SUCCESS; DWORD dwError = ERROR_SUCCESS;
BOOL bResult; BOOL bResult;
OVERLAPPED Overlapped = {0};
DPRINT("ScmControlService(%S, %d) called\n", pServiceName, dwControl);
/* Acquire the service control critical section, to synchronize requests */
EnterCriticalSection(&ControlServiceCriticalSection);
/* Calculate the total length of the start command line */
PacketSize = sizeof(SCM_CONTROL_PACKET);
PacketSize += (DWORD)((wcslen(pServiceName) + 1) * sizeof(WCHAR));
ControlPacket = HeapAlloc(GetProcessHeap(),
HEAP_ZERO_MEMORY,
PacketSize);
if (ControlPacket == NULL)
{
LeaveCriticalSection(&ControlServiceCriticalSection);
return ERROR_NOT_ENOUGH_MEMORY;
}
ControlPacket->dwSize = PacketSize;
ControlPacket->dwControl = dwControl;
ControlPacket->hServiceStatus = hServiceStatus;
ControlPacket->dwServiceNameOffset = sizeof(SCM_CONTROL_PACKET);
Ptr = (PWSTR)((PBYTE)ControlPacket + ControlPacket->dwServiceNameOffset);
wcscpy(Ptr, pServiceName);
ControlPacket->dwArgumentsCount = 0;
ControlPacket->dwArgumentsOffset = 0;
bResult = WriteFile(hControlPipe,
ControlPacket,
PacketSize,
&dwWriteCount,
&Overlapped);
if (bResult == FALSE)
{
DPRINT1("WriteFile(%S, %d) returned FALSE\n", pServiceName, dwControl);
dwError = GetLastError();
if (dwError == ERROR_IO_PENDING)
{
DPRINT("(%S, %d) dwError: ERROR_IO_PENDING\n", pServiceName, dwControl);
dwError = WaitForSingleObject(hControlPipe,
PipeTimeout);
DPRINT("WaitForSingleObject(%S, %d) returned %lu\n", pServiceName, dwControl, dwError);
if (dwError == WAIT_TIMEOUT)
{
DPRINT1("WaitForSingleObject(%S, %d) timed out\n", pServiceName, dwControl, dwError);
bResult = CancelIo(hControlPipe);
if (bResult == FALSE)
{
DPRINT1("CancelIo(%S, %d) failed (Error: %lu)\n", pServiceName, dwControl, GetLastError());
}
dwError = ERROR_SERVICE_REQUEST_TIMEOUT;
goto Done;
}
else if (dwError == WAIT_OBJECT_0)
{
bResult = GetOverlappedResult(hControlPipe,
&Overlapped,
&dwWriteCount,
TRUE);
if (bResult == FALSE)
{
dwError = GetLastError();
DPRINT1("GetOverlappedResult(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done;
}
}
}
else
{
DPRINT1("WriteFile(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done;
}
}
/* Read the reply */
Overlapped.hEvent = (HANDLE) NULL;
bResult = ReadFile(hControlPipe,
&ReplyPacket,
sizeof(SCM_REPLY_PACKET),
&dwReadCount,
&Overlapped);
if (bResult == FALSE)
{
DPRINT1("ReadFile(%S, %d) returned FALSE\n", pServiceName, dwControl);
dwError = GetLastError();
if (dwError == ERROR_IO_PENDING)
{
DPRINT("(%S, %d) dwError: ERROR_IO_PENDING\n", pServiceName, dwControl);
dwError = WaitForSingleObject(hControlPipe,
PipeTimeout);
DPRINT("WaitForSingleObject(%S, %d) returned %lu\n", pServiceName, dwControl, dwError);
if (dwError == WAIT_TIMEOUT)
{
DPRINT1("WaitForSingleObject(%S, %d) timed out\n", pServiceName, dwControl, dwError);
bResult = CancelIo(hControlPipe);
if (bResult == FALSE)
{
DPRINT1("CancelIo(%S, %d) failed (Error: %lu)\n", pServiceName, dwControl, GetLastError());
}
dwError = ERROR_SERVICE_REQUEST_TIMEOUT;
goto Done;
}
else if (dwError == WAIT_OBJECT_0)
{
bResult = GetOverlappedResult(hControlPipe,
&Overlapped,
&dwReadCount,
TRUE);
if (bResult == FALSE)
{
dwError = GetLastError();
DPRINT1("GetOverlappedResult(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done;
}
}
}
else
{
DPRINT1("ReadFile(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done;
}
}
Done:
/* Release the control packet */
HeapFree(GetProcessHeap(),
0,
ControlPacket);
if (dwReadCount == sizeof(SCM_REPLY_PACKET))
{
dwError = ReplyPacket.dwError;
}
LeaveCriticalSection(&ControlServiceCriticalSection);
DPRINT("ScmControlService(%S, %d) done\n", pServiceName, dwControl);
return dwError;
}
static DWORD
ScmSendStartCommand(PSERVICE Service,
DWORD argc,
LPWSTR* argv)
{
DWORD dwError = ERROR_SUCCESS;
PSCM_CONTROL_PACKET ControlPacket; PSCM_CONTROL_PACKET ControlPacket;
SCM_REPLY_PACKET ReplyPacket; SCM_REPLY_PACKET ReplyPacket;
DWORD PacketSize; DWORD PacketSize;
DWORD i; DWORD i;
PWSTR Ptr; PWSTR Ptr;
PWSTR *pOffPtr;
PWSTR pArgPtr;
BOOL bResult;
DWORD dwWriteCount = 0; DWORD dwWriteCount = 0;
DWORD dwReadCount = 0; DWORD dwReadCount = 0;
OVERLAPPED Overlapped = {0}; OVERLAPPED Overlapped = {0};
DPRINT("ScmSendStartCommand() called\n"); DPRINT("ScmControlService(%S, %d) called\n", pServiceName, dwControl);
/* Calculate the total length of the start command line */ /* Calculate the total size of the control packet:
* initial structure, the start command line, and the argument vector */
PacketSize = sizeof(SCM_CONTROL_PACKET); PacketSize = sizeof(SCM_CONTROL_PACKET);
PacketSize += (DWORD)((wcslen(Service->lpServiceName) + 1) * sizeof(WCHAR)); PacketSize += (DWORD)((wcslen(pServiceName) + 1) * sizeof(WCHAR));
/* /*
* Calculate the required packet size for the start argument vector 'argv', * Calculate the required packet size for the start argument vector 'argv',
* composed of the list of pointer offsets, followed by UNICODE strings. * composed of the pointer offsets list, followed by UNICODE strings.
* The strings are stored continuously after the vector of offsets, with * The strings are stored successively after the offsets vector, with
* the offsets being relative to the beginning of the vector, as in the * the offsets being relative to the beginning of the vector, as in the
* following layout (with N == argc): * following layout (with N == argc):
* [argOff(0)]...[argOff(N-1)][str(0)]...[str(N-1)] . * [argOff(0)]...[argOff(N-1)][str(0)]...[str(N-1)] .
@ -1631,22 +1462,20 @@ ScmSendStartCommand(PSERVICE Service,
} }
} }
/* Allocate a control packet */ /* Allocate the control packet */
ControlPacket = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, PacketSize); ControlPacket = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, PacketSize);
if (ControlPacket == NULL) if (!ControlPacket)
return ERROR_NOT_ENOUGH_MEMORY; return ERROR_NOT_ENOUGH_MEMORY;
ControlPacket->dwSize = PacketSize; ControlPacket->dwSize = PacketSize;
ControlPacket->dwControl = (Service->Status.dwServiceType & SERVICE_WIN32_OWN_PROCESS) ControlPacket->dwControl = dwControl;
? SERVICE_CONTROL_START_OWN ControlPacket->hServiceStatus = hServiceStatus;
: SERVICE_CONTROL_START_SHARE; ControlPacket->dwServiceTag = dwServiceTag;
ControlPacket->hServiceStatus = (SERVICE_STATUS_HANDLE)Service;
ControlPacket->dwServiceTag = Service->dwServiceTag;
/* Copy the start command line */ /* Copy the start command line */
ControlPacket->dwServiceNameOffset = sizeof(SCM_CONTROL_PACKET); ControlPacket->dwServiceNameOffset = sizeof(SCM_CONTROL_PACKET);
Ptr = (PWSTR)((ULONG_PTR)ControlPacket + ControlPacket->dwServiceNameOffset); Ptr = (PWSTR)((ULONG_PTR)ControlPacket + ControlPacket->dwServiceNameOffset);
wcscpy(Ptr, Service->lpServiceName); wcscpy(Ptr, pServiceName);
ControlPacket->dwArgumentsCount = 0; ControlPacket->dwArgumentsCount = 0;
ControlPacket->dwArgumentsOffset = 0; ControlPacket->dwArgumentsOffset = 0;
@ -1654,7 +1483,9 @@ ScmSendStartCommand(PSERVICE Service,
/* Copy the argument vector */ /* Copy the argument vector */
if (argc > 0 && argv != NULL) if (argc > 0 && argv != NULL)
{ {
Ptr += wcslen(Service->lpServiceName) + 1; PWSTR *pOffPtr, pArgPtr;
Ptr += wcslen(pServiceName) + 1;
pOffPtr = (PWSTR*)ALIGN_UP_POINTER(Ptr, PWSTR); pOffPtr = (PWSTR*)ALIGN_UP_POINTER(Ptr, PWSTR);
pArgPtr = (PWSTR)((ULONG_PTR)pOffPtr + argc * sizeof(PWSTR)); pArgPtr = (PWSTR)((ULONG_PTR)pOffPtr + argc * sizeof(PWSTR));
@ -1673,127 +1504,134 @@ ScmSendStartCommand(PSERVICE Service,
} }
} }
bResult = WriteFile(Service->lpImage->hControlPipe, /* Acquire the service control critical section, to synchronize requests */
EnterCriticalSection(&ControlServiceCriticalSection);
bResult = WriteFile(hControlPipe,
ControlPacket, ControlPacket,
PacketSize, PacketSize,
&dwWriteCount, &dwWriteCount,
&Overlapped); &Overlapped);
if (bResult == FALSE) if (!bResult)
{ {
DPRINT("WriteFile() returned FALSE\n");
dwError = GetLastError(); dwError = GetLastError();
if (dwError == ERROR_IO_PENDING) if (dwError == ERROR_IO_PENDING)
{ {
DPRINT("dwError: ERROR_IO_PENDING\n"); DPRINT("WriteFile(%S, %d) returned ERROR_IO_PENDING\n", pServiceName, dwControl);
dwError = WaitForSingleObject(Service->lpImage->hControlPipe, dwError = WaitForSingleObject(hControlPipe,
PipeTimeout); PipeTimeout);
DPRINT("WaitForSingleObject() returned %lu\n", dwError); DPRINT("WaitForSingleObject(%S, %d) returned %lu\n", pServiceName, dwControl, dwError);
if (dwError == WAIT_TIMEOUT) if (dwError == WAIT_TIMEOUT)
{ {
bResult = CancelIo(Service->lpImage->hControlPipe); DPRINT1("WaitForSingleObject(%S, %d) timed out\n", pServiceName, dwControl);
if (bResult == FALSE) bResult = CancelIo(hControlPipe);
{ if (!bResult)
DPRINT1("CancelIo() failed (Error: %lu)\n", GetLastError()); DPRINT1("CancelIo(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, GetLastError());
}
dwError = ERROR_SERVICE_REQUEST_TIMEOUT; dwError = ERROR_SERVICE_REQUEST_TIMEOUT;
goto Done; goto Done;
} }
else if (dwError == WAIT_OBJECT_0) else if (dwError == WAIT_OBJECT_0)
{ {
bResult = GetOverlappedResult(Service->lpImage->hControlPipe, bResult = GetOverlappedResult(hControlPipe,
&Overlapped, &Overlapped,
&dwWriteCount, &dwWriteCount,
TRUE); TRUE);
if (bResult == FALSE) if (!bResult)
{ {
dwError = GetLastError(); dwError = GetLastError();
DPRINT1("GetOverlappedResult() failed (Error %lu)\n", dwError); DPRINT1("GetOverlappedResult(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done; goto Done;
} }
} }
} }
else else
{ {
DPRINT1("WriteFile() failed (Error %lu)\n", dwError); DPRINT1("WriteFile(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done; goto Done;
} }
} }
/* Read the reply */ /* Read the reply */
Overlapped.hEvent = (HANDLE) NULL; Overlapped.hEvent = NULL;
bResult = ReadFile(Service->lpImage->hControlPipe, bResult = ReadFile(hControlPipe,
&ReplyPacket, &ReplyPacket,
sizeof(SCM_REPLY_PACKET), sizeof(ReplyPacket),
&dwReadCount, &dwReadCount,
&Overlapped); &Overlapped);
if (bResult == FALSE) if (!bResult)
{ {
DPRINT("ReadFile() returned FALSE\n");
dwError = GetLastError(); dwError = GetLastError();
if (dwError == ERROR_IO_PENDING) if (dwError == ERROR_IO_PENDING)
{ {
DPRINT("dwError: ERROR_IO_PENDING\n"); DPRINT("ReadFile(%S, %d) returned ERROR_IO_PENDING\n", pServiceName, dwControl);
dwError = WaitForSingleObject(Service->lpImage->hControlPipe, dwError = WaitForSingleObject(hControlPipe,
PipeTimeout); PipeTimeout);
DPRINT("WaitForSingleObject() returned %lu\n", dwError); DPRINT("WaitForSingleObject(%S, %d) returned %lu\n", pServiceName, dwControl, dwError);
if (dwError == WAIT_TIMEOUT) if (dwError == WAIT_TIMEOUT)
{ {
bResult = CancelIo(Service->lpImage->hControlPipe); DPRINT1("WaitForSingleObject(%S, %d) timed out\n", pServiceName, dwControl);
if (bResult == FALSE) bResult = CancelIo(hControlPipe);
{ if (!bResult)
DPRINT1("CancelIo() failed (Error: %lu)\n", GetLastError()); DPRINT1("CancelIo(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, GetLastError());
}
dwError = ERROR_SERVICE_REQUEST_TIMEOUT; dwError = ERROR_SERVICE_REQUEST_TIMEOUT;
goto Done; goto Done;
} }
else if (dwError == WAIT_OBJECT_0) else if (dwError == WAIT_OBJECT_0)
{ {
bResult = GetOverlappedResult(Service->lpImage->hControlPipe, bResult = GetOverlappedResult(hControlPipe,
&Overlapped, &Overlapped,
&dwReadCount, &dwReadCount,
TRUE); TRUE);
if (bResult == FALSE) if (!bResult)
{ {
dwError = GetLastError(); dwError = GetLastError();
DPRINT1("GetOverlappedResult() failed (Error %lu)\n", dwError); DPRINT1("GetOverlappedResult(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done; goto Done;
} }
} }
} }
else else
{ {
DPRINT1("ReadFile() failed (Error %lu)\n", dwError); DPRINT1("ReadFile(%S, %d) failed (Error %lu)\n", pServiceName, dwControl, dwError);
goto Done; goto Done;
} }
} }
Done: Done:
/* Release the control packet */ /* Release the service control critical section */
HeapFree(GetProcessHeap(), LeaveCriticalSection(&ControlServiceCriticalSection);
0,
ControlPacket);
if (dwReadCount == sizeof(SCM_REPLY_PACKET)) /* Free the control packet */
{ HeapFree(GetProcessHeap(), 0, ControlPacket);
if (dwReadCount == sizeof(ReplyPacket))
dwError = ReplyPacket.dwError; dwError = ReplyPacket.dwError;
}
DPRINT("ScmSendStartCommand() done\n");
DPRINT("ScmControlService(%S, %d) done (Error %lu)\n", pServiceName, dwControl, dwError);
return dwError; return dwError;
} }
DWORD
ScmControlService(
_In_ HANDLE hControlPipe,
_In_ PCWSTR pServiceName,
_In_ DWORD dwControl,
_In_ SERVICE_STATUS_HANDLE hServiceStatus)
{
return ScmControlServiceEx(hControlPipe,
pServiceName,
dwControl,
hServiceStatus,
0, 0, NULL);
}
static DWORD static DWORD
ScmWaitForServiceConnect(PSERVICE Service) ScmWaitForServiceConnect(PSERVICE Service)
@ -1811,8 +1649,6 @@ ScmWaitForServiceConnect(PSERVICE Service)
DPRINT("ScmWaitForServiceConnect()\n"); DPRINT("ScmWaitForServiceConnect()\n");
Overlapped.hEvent = (HANDLE)NULL;
bResult = ConnectNamedPipe(Service->lpImage->hControlPipe, bResult = ConnectNamedPipe(Service->lpImage->hControlPipe,
&Overlapped); &Overlapped);
if (bResult == FALSE) if (bResult == FALSE)
@ -1876,7 +1712,7 @@ ScmWaitForServiceConnect(PSERVICE Service)
DPRINT("Control pipe connected\n"); DPRINT("Control pipe connected\n");
Overlapped.hEvent = (HANDLE) NULL; Overlapped.hEvent = NULL;
/* Read the process id from pipe */ /* Read the process id from pipe */
bResult = ReadFile(Service->lpImage->hControlPipe, bResult = ReadFile(Service->lpImage->hControlPipe,
@ -1987,12 +1823,9 @@ ScmStartUserModeService(PSERVICE Service,
DPRINT("ScmStartUserModeService(%p)\n", Service); DPRINT("ScmStartUserModeService(%p)\n", Service);
/* If the image is already running ... */ /* If the image is already running, just send a start command */
if (Service->lpImage->dwImageRunCount > 1) if (Service->lpImage->dwImageRunCount > 1)
{ goto Quit;
/* ... just send a start command */
return ScmSendStartCommand(Service, argc, argv);
}
/* Otherwise start its process */ /* Otherwise start its process */
ZeroMemory(&StartupInfo, sizeof(StartupInfo)); ZeroMemory(&StartupInfo, sizeof(StartupInfo));
@ -2115,8 +1948,15 @@ ScmStartUserModeService(PSERVICE Service,
return dwError; return dwError;
} }
/* Send the start command */ Quit:
return ScmSendStartCommand(Service, argc, argv); /* Send the start command and return */
return ScmControlServiceEx(Service->lpImage->hControlPipe,
Service->lpServiceName,
(Service->Status.dwServiceType & SERVICE_WIN32_OWN_PROCESS)
? SERVICE_CONTROL_START_OWN : SERVICE_CONTROL_START_SHARE,
(SERVICE_STATUS_HANDLE)Service,
Service->dwServiceTag,
argc, argv);
} }
@ -2508,8 +2348,8 @@ ScmAutoShutdownServices(VOID)
DPRINT("Shutdown service: %S\n", CurrentService->lpServiceName); DPRINT("Shutdown service: %S\n", CurrentService->lpServiceName);
ScmControlService(CurrentService->lpImage->hControlPipe, ScmControlService(CurrentService->lpImage->hControlPipe,
CurrentService->lpServiceName, CurrentService->lpServiceName,
(SERVICE_STATUS_HANDLE)CurrentService, SERVICE_CONTROL_SHUTDOWN,
SERVICE_CONTROL_SHUTDOWN); (SERVICE_STATUS_HANDLE)CurrentService);
} }
ServiceEntry = ServiceEntry->Flink; ServiceEntry = ServiceEntry->Flink;

View file

@ -1187,8 +1187,8 @@ RControlService(
/* Send control code to the service */ /* Send control code to the service */
dwError = ScmControlService(lpService->lpImage->hControlPipe, dwError = ScmControlService(lpService->lpImage->hControlPipe,
lpService->lpServiceName, lpService->lpServiceName,
(SERVICE_STATUS_HANDLE)lpService, dwControl,
dwControl); (SERVICE_STATUS_HANDLE)lpService);
/* Return service status information */ /* Return service status information */
RtlCopyMemory(lpServiceStatus, RtlCopyMemory(lpServiceStatus,
@ -1626,8 +1626,8 @@ ScmStopThread(PVOID pParam)
DPRINT("Stopping the dispatcher thread for service %S\n", lpService->lpServiceName); DPRINT("Stopping the dispatcher thread for service %S\n", lpService->lpServiceName);
ScmControlService(lpService->lpImage->hControlPipe, ScmControlService(lpService->lpImage->hControlPipe,
L"", L"",
(SERVICE_STATUS_HANDLE)lpService, SERVICE_CONTROL_STOP,
SERVICE_CONTROL_STOP); (SERVICE_STATUS_HANDLE)lpService);
} }
/* Lock the service database exclusively */ /* Lock the service database exclusively */

View file

@ -18,10 +18,12 @@
#include <winreg.h> #include <winreg.h>
#include <winuser.h> #include <winuser.h>
#include <netevent.h> #include <netevent.h>
#define NTOS_MODE_USER #define NTOS_MODE_USER
#include <ndk/setypes.h> #include <ndk/setypes.h>
#include <ndk/obfuncs.h> #include <ndk/obfuncs.h>
#include <ndk/rtlfuncs.h> #include <ndk/rtlfuncs.h>
#include <services/services.h> #include <services/services.h>
#include <svcctl_s.h> #include <svcctl_s.h>
@ -200,10 +202,12 @@ DWORD ScmCreateNewServiceRecord(LPCWSTR lpServiceName,
VOID ScmDeleteServiceRecord(PSERVICE lpService); VOID ScmDeleteServiceRecord(PSERVICE lpService);
DWORD ScmMarkServiceForDelete(PSERVICE pService); DWORD ScmMarkServiceForDelete(PSERVICE pService);
DWORD ScmControlService(HANDLE hControlPipe, DWORD
PWSTR pServiceName, ScmControlService(
SERVICE_STATUS_HANDLE hServiceStatus, _In_ HANDLE hControlPipe,
DWORD dwControl); _In_ PCWSTR pServiceName,
_In_ DWORD dwControl,
_In_ SERVICE_STATUS_HANDLE hServiceStatus);
BOOL ScmLockDatabaseExclusive(VOID); BOOL ScmLockDatabaseExclusive(VOID);
BOOL ScmLockDatabaseShared(VOID); BOOL ScmLockDatabaseShared(VOID);