Cleaned up the driver loading routines.

svn path=/trunk/; revision=3058
This commit is contained in:
Eric Kohl 2002-06-12 23:33:15 +00:00
parent 2137580260
commit 9d3870ce78

View file

@ -1,4 +1,4 @@
/* $Id: database.c,v 1.1 2002/06/07 20:09:56 ekohl Exp $ /* $Id: database.c,v 1.2 2002/06/12 23:33:15 ekohl Exp $
* *
* service control manager * service control manager
* *
@ -42,7 +42,7 @@
typedef struct _SERVICE_GROUP typedef struct _SERVICE_GROUP
{ {
LIST_ENTRY GroupListEntry; LIST_ENTRY GroupListEntry;
PWSTR GroupName; UNICODE_STRING GroupName;
BOOLEAN ServicesRunning; BOOLEAN ServicesRunning;
@ -52,17 +52,16 @@ typedef struct _SERVICE_GROUP
typedef struct _SERVICE typedef struct _SERVICE
{ {
LIST_ENTRY ServiceListEntry; LIST_ENTRY ServiceListEntry;
PWSTR ServiceName; UNICODE_STRING ServiceName;
PWSTR GroupName; UNICODE_STRING RegistryPath;
UNICODE_STRING ServiceGroup;
PWSTR ImagePath;
ULONG Start; ULONG Start;
ULONG Type; ULONG Type;
ULONG ErrorControl; ULONG ErrorControl;
ULONG Tag; ULONG Tag;
BOOLEAN ServiceRunning; // needed ?? BOOLEAN ServiceRunning;
} SERVICE, *PSERVICE; } SERVICE, *PSERVICE;
@ -70,13 +69,11 @@ typedef struct _SERVICE
/* GLOBALS *******************************************************************/ /* GLOBALS *******************************************************************/
LIST_ENTRY GroupListHead = {NULL, NULL}; LIST_ENTRY GroupListHead = {NULL, NULL};
LIST_ENTRY ServiceListHead = {NULL, NULL}; LIST_ENTRY ServiceListHead = {NULL, NULL};
/* FUNCTIONS *****************************************************************/ /* FUNCTIONS *****************************************************************/
static NTSTATUS STDCALL static NTSTATUS STDCALL
CreateGroupListRoutine(PWSTR ValueName, CreateGroupListRoutine(PWSTR ValueName,
ULONG ValueType, ULONG ValueType,
@ -95,23 +92,19 @@ CreateGroupListRoutine(PWSTR ValueName,
HEAP_ZERO_MEMORY, HEAP_ZERO_MEMORY,
sizeof(SERVICE_GROUP)); sizeof(SERVICE_GROUP));
if (Group == NULL) if (Group == NULL)
{
return(STATUS_INSUFFICIENT_RESOURCES); return(STATUS_INSUFFICIENT_RESOURCES);
}
if (!RtlCreateUnicodeString(&Group->GroupName,
Group->GroupName = (PWSTR)HeapAlloc(GetProcessHeap(), (PWSTR)ValueData))
HEAP_ZERO_MEMORY, {
ValueLength);
if (Group->GroupName == NULL)
return(STATUS_INSUFFICIENT_RESOURCES); return(STATUS_INSUFFICIENT_RESOURCES);
}
wcscpy(Group->GroupName,
(PWSTR)ValueData);
InsertTailList(&GroupListHead, InsertTailList(&GroupListHead,
&Group->GroupListEntry); &Group->GroupListEntry);
} }
return(STATUS_SUCCESS); return(STATUS_SUCCESS);
@ -122,50 +115,49 @@ static NTSTATUS STDCALL
CreateServiceListEntry(PUNICODE_STRING ServiceName) CreateServiceListEntry(PUNICODE_STRING ServiceName)
{ {
RTL_QUERY_REGISTRY_TABLE QueryTable[6]; RTL_QUERY_REGISTRY_TABLE QueryTable[6];
WCHAR ServiceGroupBuffer[MAX_PATH]; PSERVICE Service = NULL;
WCHAR ImagePathBuffer[MAX_PATH];
UNICODE_STRING ServiceGroup;
UNICODE_STRING ImagePath;
PSERVICE_GROUP Group;
PSERVICE Service;
NTSTATUS Status; NTSTATUS Status;
// PrintString("Service: '%wZ'\n", ServiceName); // PrintString("Service: '%wZ'\n", ServiceName);
Service = (PSERVICE)HeapAlloc(GetProcessHeap(), /* Allocate service entry */
HEAP_ZERO_MEMORY, Service = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY,
sizeof(SERVICE)); sizeof(SERVICE));
if (Service == NULL) if (Service == NULL)
{ {
PrintString(" - HeapAlloc() (1) failed\n");
return(STATUS_INSUFFICIENT_RESOURCES); return(STATUS_INSUFFICIENT_RESOURCES);
} }
Service->ServiceName = (PWSTR)HeapAlloc(GetProcessHeap(), /* Copy service name */
HEAP_ZERO_MEMORY, Service->ServiceName.Length = ServiceName->Length;
ServiceName->Length); Service->ServiceName.MaximumLength = ServiceName->Length + sizeof(WCHAR);
if (Service->ServiceName == NULL) Service->ServiceName.Buffer = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY,
Service->ServiceName.MaximumLength);
if (Service->ServiceName.Buffer == NULL)
{ {
PrintString(" - HeapAlloc() (2) failed\n"); HeapFree(GetProcessHeap(), 0, Service);
return(STATUS_INSUFFICIENT_RESOURCES); return(STATUS_INSUFFICIENT_RESOURCES);
} }
RtlCopyMemory(Service->ServiceName.Buffer,
ServiceName->Buffer,
ServiceName->Length);
Service->ServiceName.Buffer[ServiceName->Length / sizeof(WCHAR)] = 0;
wcscpy(Service->ServiceName, /* Build registry path */
ServiceName->Buffer); Service->RegistryPath.MaximumLength = MAX_PATH * sizeof(WCHAR);
Service->RegistryPath.Buffer = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY,
ServiceGroup.Length = 0;
ServiceGroup.MaximumLength = MAX_PATH * sizeof(WCHAR);
ServiceGroup.Buffer = ServiceGroupBuffer;
RtlZeroMemory(ServiceGroupBuffer,
MAX_PATH * sizeof(WCHAR)); MAX_PATH * sizeof(WCHAR));
if (Service->ServiceName.Buffer == NULL)
ImagePath.Length = 0; {
ImagePath.MaximumLength = MAX_PATH * sizeof(WCHAR); HeapFree(GetProcessHeap(), 0, Service->ServiceName.Buffer);
ImagePath.Buffer = ImagePathBuffer; HeapFree(GetProcessHeap(), 0, Service);
RtlZeroMemory(ImagePathBuffer, return(STATUS_INSUFFICIENT_RESOURCES);
MAX_PATH * sizeof(WCHAR)); }
wcscpy(Service->RegistryPath.Buffer,
L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\");
wcscat(Service->RegistryPath.Buffer,
Service->ServiceName.Buffer);
Service->RegistryPath.Length = wcslen(Service->RegistryPath.Buffer) * sizeof(WCHAR);
/* Get service data */ /* Get service data */
RtlZeroMemory(&QueryTable, RtlZeroMemory(&QueryTable,
@ -185,12 +177,7 @@ CreateServiceListEntry(PUNICODE_STRING ServiceName)
QueryTable[3].Name = L"Group"; QueryTable[3].Name = L"Group";
QueryTable[3].Flags = RTL_QUERY_REGISTRY_DIRECT; QueryTable[3].Flags = RTL_QUERY_REGISTRY_DIRECT;
QueryTable[3].EntryContext = &ServiceGroup; QueryTable[3].EntryContext = &Service->ServiceGroup;
QueryTable[4].Name = L"ImagePath";
QueryTable[4].Flags = RTL_QUERY_REGISTRY_DIRECT;
QueryTable[4].EntryContext = &ImagePath;
Status = RtlQueryRegistryValues(RTL_REGISTRY_SERVICES, Status = RtlQueryRegistryValues(RTL_REGISTRY_SERVICES,
ServiceName->Buffer, ServiceName->Buffer,
@ -200,61 +187,24 @@ CreateServiceListEntry(PUNICODE_STRING ServiceName)
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
PrintString("RtlQueryRegistryValues() failed (Status %lx)\n", Status); PrintString("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
RtlFreeUnicodeString(&Service->RegistryPath);
RtlFreeUnicodeString(&Service->ServiceName);
HeapFree(GetProcessHeap(), 0, Service);
return(Status); return(Status);
} }
/* Copy the service group name */ #if 0
if (ServiceGroup.Length > 0) PrintString("ServiceName: '%wZ'\n", &Service->ServiceName);
{ PrintString("RegistryPath: '%wZ'\n", &Service->RegistryPath);
Service->GroupName = (PWSTR)HeapAlloc(GetProcessHeap(), PrintString("ServiceGroup: '%wZ'\n", &Service->ServiceGroup);
HEAP_ZERO_MEMORY, PrintString("Start %lx Type %lx ErrorControl %lx\n",
ServiceGroup.Length + sizeof(WCHAR)); Service->Start, Service->Type, Service->ErrorControl);
if (Service->GroupName == NULL) #endif
{
PrintString(" - HeapAlloc() (3) failed\n");
return(STATUS_INSUFFICIENT_RESOURCES);
}
memcpy(Service->GroupName,
ServiceGroup.Buffer,
ServiceGroup.Length);
}
else
{
Service->GroupName = NULL;
}
/* Copy the image path */
if (ImagePath.Length > 0)
{
Service->ImagePath = (PWSTR)HeapAlloc(GetProcessHeap(),
HEAP_ZERO_MEMORY,
ImagePath.Length + sizeof(WCHAR));
if (Service->ImagePath == NULL)
{
PrintString(" - HeapAlloc() (4) failed\n");
return(STATUS_INSUFFICIENT_RESOURCES);
}
memcpy(Service->ImagePath,
ImagePath.Buffer,
ImagePath.Length);
}
else
{
Service->ImagePath = NULL;
}
// PrintString(" Type: %lx\n", Service->Type);
// PrintString(" Start: %lx\n", Service->Start);
// PrintString(" Group: '%wZ'\n", &ServiceGroup);
/* Append service entry */ /* Append service entry */
InsertTailList(&ServiceListHead, InsertTailList(&ServiceListHead,
&Service->ServiceListEntry); &Service->ServiceListEntry);
return(STATUS_SUCCESS); return(STATUS_SUCCESS);
} }
@ -268,8 +218,14 @@ ScmCreateServiceDataBase(VOID)
UNICODE_STRING ServicesKeyName; UNICODE_STRING ServicesKeyName;
UNICODE_STRING SubKeyName; UNICODE_STRING SubKeyName;
HKEY ServicesKey; HKEY ServicesKey;
NTSTATUS Status;
ULONG Index; ULONG Index;
NTSTATUS Status;
PKEY_BASIC_INFORMATION KeyInfo = NULL;
ULONG KeyInfoLength = 0;
ULONG ReturnedLength;
// PrintString("ScmCreateServiceDataBase() called\n");
/* Initialize basic variables */ /* Initialize basic variables */
InitializeListHead(&GroupListHead); InitializeListHead(&GroupListHead);
@ -291,7 +247,6 @@ ScmCreateServiceDataBase(VOID)
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
return(Status); return(Status);
RtlInitUnicodeString(&ServicesKeyName, RtlInitUnicodeString(&ServicesKeyName,
L"\\Registry\\Machine\\System\\CurrentControlSet\\Services"); L"\\Registry\\Machine\\System\\CurrentControlSet\\Services");
@ -308,25 +263,48 @@ ScmCreateServiceDataBase(VOID)
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
return(Status); return(Status);
SubKeyName.Length = 0; /* Allocate key info buffer */
SubKeyName.MaximumLength = MAX_PATH * sizeof(WCHAR); KeyInfoLength = sizeof(KEY_BASIC_INFORMATION) + MAX_PATH * sizeof(WCHAR);
SubKeyName.Buffer = NameBuffer; KeyInfo = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, KeyInfoLength);
if (KeyInfo == NULL)
{
NtClose(ServicesKey);
return(STATUS_INSUFFICIENT_RESOURCES);
}
Index = 0; Index = 0;
while (TRUE) while (TRUE)
{ {
Status = RtlpNtEnumerateSubKey(ServicesKey, Status = NtEnumerateKey(ServicesKey,
&SubKeyName,
Index, Index,
0); KeyBasicInformation,
KeyInfo,
KeyInfoLength,
&ReturnedLength);
if (NT_SUCCESS(Status))
{
if (KeyInfo->NameLength < MAX_PATH * sizeof(WCHAR))
{
SubKeyName.Length = KeyInfo->NameLength;
SubKeyName.MaximumLength = KeyInfo->NameLength + sizeof(WCHAR);
SubKeyName.Buffer = KeyInfo->Name;
SubKeyName.Buffer[SubKeyName.Length / sizeof(WCHAR)] = 0;
// PrintString("KeyName: '%wZ'\n", &SubKeyName);
Status = CreateServiceListEntry(&SubKeyName);
}
}
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
break; break;
CreateServiceListEntry(&SubKeyName);
Index++; Index++;
} }
HeapFree(GetProcessHeap(), 0, KeyInfo);
NtClose(ServicesKey);
// PrintString("ScmCreateServiceDataBase() done\n"); // PrintString("ScmCreateServiceDataBase() done\n");
return(STATUS_SUCCESS); return(STATUS_SUCCESS);
@ -340,34 +318,6 @@ ScmGetBootAndSystemDriverState(VOID)
} }
static NTSTATUS
ScmLoadDriver(PSERVICE Service)
{
WCHAR ServicePath[MAX_PATH];
UNICODE_STRING DriverPath;
// PrintString("ScmLoadDriver(%S) called\n", Service->ServiceName);
if (Service->ImagePath == NULL)
{
wcscpy(ServicePath, L"\\SystemRoot\\system32\\drivers\\");
wcscat(ServicePath, Service->ServiceName);
wcscat(ServicePath, L".sys");
}
else
{
wcscpy(ServicePath, L"\\SystemRoot\\");
wcscat(ServicePath, Service->ImagePath);
}
RtlInitUnicodeString(&DriverPath, ServicePath);
// PrintString(" DriverPath: '%wZ'\n", &DriverPath);
return(NtLoadDriver(&DriverPath));
}
static NTSTATUS static NTSTATUS
ScmStartService(PSERVICE Service) ScmStartService(PSERVICE Service)
{ {
@ -376,11 +326,9 @@ ScmStartService(PSERVICE Service)
STARTUPINFO StartupInfo; STARTUPINFO StartupInfo;
WCHAR CommandLine[MAX_PATH]; WCHAR CommandLine[MAX_PATH];
BOOL Result; BOOL Result;
#endif
PrintString("ScmStartService(%S) called\n", Service->ServiceName); PrintString("ScmStartService() called\n");
#if 0
GetSystemDirectoryW(CommandLine, MAX_PATH); GetSystemDirectoryW(CommandLine, MAX_PATH);
_tcscat(CommandLine, "\\"); _tcscat(CommandLine, "\\");
_tcscat(CommandLine, FileName); _tcscat(CommandLine, FileName);
@ -436,14 +384,14 @@ ScmAutoStartServices(VOID)
{ {
CurrentGroup = CONTAINING_RECORD(GroupEntry, SERVICE_GROUP, GroupListEntry); CurrentGroup = CONTAINING_RECORD(GroupEntry, SERVICE_GROUP, GroupListEntry);
// PrintString(" %S\n", CurrentGroup->GroupName); // PrintString("Group '%wZ'\n", &CurrentGroup->GroupName);
ServiceEntry = ServiceListHead.Flink; ServiceEntry = ServiceListHead.Flink;
while (ServiceEntry != &ServiceListHead) while (ServiceEntry != &ServiceListHead)
{ {
CurrentService = CONTAINING_RECORD(ServiceEntry, SERVICE, ServiceListEntry); CurrentService = CONTAINING_RECORD(ServiceEntry, SERVICE, ServiceListEntry);
if ((wcsicmp(CurrentGroup->GroupName, CurrentService->GroupName) == 0) && if ((RtlCompareUnicodeString(&CurrentGroup->GroupName, &CurrentService->ServiceGroup, TRUE) == 0) &&
(CurrentService->Start == SERVICE_AUTO_START)) (CurrentService->Start == SERVICE_AUTO_START))
{ {
if (CurrentService->Type == SERVICE_KERNEL_DRIVER || if (CurrentService->Type == SERVICE_KERNEL_DRIVER ||
@ -451,7 +399,8 @@ ScmAutoStartServices(VOID)
CurrentService->Type == SERVICE_RECOGNIZER_DRIVER) CurrentService->Type == SERVICE_RECOGNIZER_DRIVER)
{ {
/* Load driver */ /* Load driver */
Status = ScmLoadDriver(CurrentService); // PrintString(" Path: %wZ\n", &CurrentService->RegistryPath);
Status = NtLoadDriver(&CurrentService->RegistryPath);
} }
else else
{ {
@ -499,10 +448,8 @@ ScmAutoStartServices(VOID)
} }
ServiceEntry = ServiceEntry->Flink; ServiceEntry = ServiceEntry->Flink;
} }
GroupEntry = GroupEntry->Flink; GroupEntry = GroupEntry->Flink;
} }
} }
/* EOF */ /* EOF */