- NtLoadDriver APIs refactor/improvent:

* Move loading of a driver to the common (in future) routine IopLoadUnloadDriver (name taken from http://wasm.ru/forum/viewtopic.php?pid=166891).
 * Fix a bug: NtLoadDriver should always load drivers in a context of the system process, explanation here: http://www.osronline.com/showthread.cfm?link=114687
 * Reformat NtLoadDriver's code.

svn path=/trunk/; revision=29370
This commit is contained in:
Aleksey Bragin 2007-10-03 10:17:04 +00:00
parent 597b822c7e
commit 2ea98ca374
2 changed files with 116 additions and 63 deletions

View file

@ -349,6 +349,17 @@ typedef struct _OPEN_PACKET
//PIO_DRIVER_CREATE_CONTEXT DriverCreateContext; Vista only, needs ROS DDK Update
} OPEN_PACKET, *POPEN_PACKET;
//
// Parameters packet for Load/Unload work item's context
//
typedef struct _LOAD_UNLOAD_PARAMS
{
NTSTATUS Status;
PUNICODE_STRING ServiceName;
WORK_QUEUE_ITEM WorkItem;
KEVENT Event;
} LOAD_UNLOAD_PARAMS, *PLOAD_UNLOAD_PARAMS;
//
// List of Bus Type GUIDs
//

View file

@ -1465,6 +1465,165 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT DriverObject,
return DriverExtensions + 1;
}
VOID NTAPI
IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams)
{
RTL_QUERY_REGISTRY_TABLE QueryTable[3];
UNICODE_STRING ImagePath;
UNICODE_STRING ServiceName;
NTSTATUS Status;
ULONG Type;
PDEVICE_NODE DeviceNode;
PLDR_DATA_TABLE_ENTRY ModuleObject;
PDRIVER_OBJECT DriverObject;
WCHAR *cur;
RtlInitUnicodeString(&ImagePath, NULL);
/*
* Get the service name from the registry key name.
*/
ASSERT(LoadParams->ServiceName->Length >= sizeof(WCHAR));
ServiceName = *LoadParams->ServiceName;
cur = LoadParams->ServiceName->Buffer +
(LoadParams->ServiceName->Length / sizeof(WCHAR)) - 1;
while (LoadParams->ServiceName->Buffer != cur)
{
if(*cur == L'\\')
{
ServiceName.Buffer = cur + 1;
ServiceName.Length = LoadParams->ServiceName->Length -
(USHORT)((ULONG_PTR)ServiceName.Buffer -
(ULONG_PTR)LoadParams->ServiceName->Buffer);
break;
}
cur--;
}
/*
* Get service type.
*/
RtlZeroMemory(&QueryTable, sizeof(QueryTable));
RtlInitUnicodeString(&ImagePath, NULL);
QueryTable[0].Name = L"Type";
QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED;
QueryTable[0].EntryContext = &Type;
QueryTable[1].Name = L"ImagePath";
QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT;
QueryTable[1].EntryContext = &ImagePath;
Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
LoadParams->ServiceName->Buffer, QueryTable, NULL, NULL);
if (!NT_SUCCESS(Status))
{
DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
ExFreePool(ImagePath.Buffer);
LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
}
/*
* Normalize the image path for all later processing.
*/
Status = IopNormalizeImagePath(&ImagePath, &ServiceName);
if (!NT_SUCCESS(Status))
{
DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status);
LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
}
DPRINT("FullImagePath: '%wZ'\n", &ImagePath);
DPRINT("Type: %lx\n", Type);
/*
* Create device node
*/
/* Use IopRootDeviceNode for now */
Status = IopCreateDeviceNode(IopRootDeviceNode, NULL, &ServiceName, &DeviceNode);
if (!NT_SUCCESS(Status))
{
DPRINT("IopCreateDeviceNode() failed (Status %lx)\n", Status);
LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
}
/* Get existing DriverObject pointer (in case the driver has
already been loaded and initialized) */
Status = IopGetDriverObject(
&DriverObject,
&ServiceName,
(Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ ||
Type == 8 /* SERVICE_RECOGNIZER_DRIVER */));
if (!NT_SUCCESS(Status))
{
/*
* Load the driver module
*/
Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, (PVOID)&ModuleObject, NULL);
if (!NT_SUCCESS(Status) && Status != STATUS_IMAGE_ALREADY_LOADED)
{
DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status);
IopFreeDeviceNode(DeviceNode);
LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
}
/*
* Set a service name for the device node
*/
RtlCreateUnicodeString(&DeviceNode->ServiceName, ServiceName.Buffer);
/*
* Initialize the driver module if it's loaded for the first time
*/
if (Status != STATUS_IMAGE_ALREADY_LOADED)
{
Status = IopInitializeDriverModule(
DeviceNode,
ModuleObject,
&DeviceNode->ServiceName,
(Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ ||
Type == 8 /* SERVICE_RECOGNIZER_DRIVER */),
&DriverObject);
if (!NT_SUCCESS(Status))
{
DPRINT("IopInitializeDriver() failed (Status %lx)\n", Status);
MmUnloadSystemImage(ModuleObject);
IopFreeDeviceNode(DeviceNode);
LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
}
}
/* We have a driver for this DeviceNode */
DeviceNode->Flags |= DN_DRIVER_LOADED;
}
IopInitializeDevice(DeviceNode, DriverObject);
LoadParams->Status = IopStartDevice(DeviceNode);
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
}
/*
* NtLoadDriver
*
@ -1483,17 +1642,10 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT DriverObject,
NTSTATUS STDCALL
NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
{
RTL_QUERY_REGISTRY_TABLE QueryTable[3];
UNICODE_STRING ImagePath;
UNICODE_STRING ServiceName;
UNICODE_STRING CapturedDriverServiceName = {0};
KPROCESSOR_MODE PreviousMode;
LOAD_UNLOAD_PARAMS LoadParams;
NTSTATUS Status;
ULONG Type;
PDEVICE_NODE DeviceNode;
PLDR_DATA_TABLE_ENTRY ModuleObject;
PDRIVER_OBJECT DriverObject;
WCHAR *cur;
PAGED_CODE();
@ -1522,144 +1674,34 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
DPRINT("NtLoadDriver('%wZ')\n", &CapturedDriverServiceName);
RtlInitUnicodeString(&ImagePath, NULL);
LoadParams.ServiceName = &CapturedDriverServiceName;
KeInitializeEvent(&LoadParams.Event, NotificationEvent, FALSE);
/*
* Get the service name from the registry key name.
*/
ASSERT(CapturedDriverServiceName.Length >= sizeof(WCHAR));
ServiceName = CapturedDriverServiceName;
cur = CapturedDriverServiceName.Buffer + (CapturedDriverServiceName.Length / sizeof(WCHAR)) - 1;
while (CapturedDriverServiceName.Buffer != cur)
/* Call the load/unload routine, depending on current process */
if (PsGetCurrentProcess() == PsInitialSystemProcess)
{
if(*cur == L'\\')
{
ServiceName.Buffer = cur + 1;
ServiceName.Length = CapturedDriverServiceName.Length -
(USHORT)((ULONG_PTR)ServiceName.Buffer -
(ULONG_PTR)CapturedDriverServiceName.Buffer);
break;
/* Just call right away */
IopLoadUnloadDriver(&LoadParams);
}
cur--;
else
{
/* Load/Unload must be called from system process */
ExInitializeWorkItem(&LoadParams.WorkItem,
(PWORKER_THREAD_ROUTINE)IopLoadUnloadDriver,
(PVOID)&LoadParams);
/* Queue it */
ExQueueWorkItem(&LoadParams.WorkItem, DelayedWorkQueue);
/* And wait when it completes */
KeWaitForSingleObject(&LoadParams.Event, UserRequest, KernelMode,
FALSE, NULL);
}
/*
* Get service type.
*/
RtlZeroMemory(&QueryTable, sizeof(QueryTable));
RtlInitUnicodeString(&ImagePath, NULL);
QueryTable[0].Name = L"Type";
QueryTable[0].Flags = RTL_QUERY_REGISTRY_DIRECT | RTL_QUERY_REGISTRY_REQUIRED;
QueryTable[0].EntryContext = &Type;
QueryTable[1].Name = L"ImagePath";
QueryTable[1].Flags = RTL_QUERY_REGISTRY_DIRECT;
QueryTable[1].EntryContext = &ImagePath;
Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
CapturedDriverServiceName.Buffer, QueryTable, NULL, NULL);
if (!NT_SUCCESS(Status))
{
DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
ExFreePool(ImagePath.Buffer);
goto ReleaseCapturedString;
}
/*
* Normalize the image path for all later processing.
*/
Status = IopNormalizeImagePath(&ImagePath, &ServiceName);
if (!NT_SUCCESS(Status))
{
DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status);
goto ReleaseCapturedString;
}
DPRINT("FullImagePath: '%wZ'\n", &ImagePath);
DPRINT("Type: %lx\n", Type);
/*
* Create device node
*/
/* Use IopRootDeviceNode for now */
Status = IopCreateDeviceNode(IopRootDeviceNode, NULL, &ServiceName, &DeviceNode);
if (!NT_SUCCESS(Status))
{
DPRINT("IopCreateDeviceNode() failed (Status %lx)\n", Status);
goto ReleaseCapturedString;
}
/* Get existing DriverObject pointer (in case the driver has
already been loaded and initialized) */
Status = IopGetDriverObject(
&DriverObject,
&ServiceName,
(Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ ||
Type == 8 /* SERVICE_RECOGNIZER_DRIVER */));
if (!NT_SUCCESS(Status))
{
/*
* Load the driver module
*/
Status = MmLoadSystemImage(&ImagePath, NULL, NULL, 0, (PVOID)&ModuleObject, NULL);
if (!NT_SUCCESS(Status) && Status != STATUS_IMAGE_ALREADY_LOADED)
{
DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status);
IopFreeDeviceNode(DeviceNode);
goto ReleaseCapturedString;
}
/*
* Set a service name for the device node
*/
RtlCreateUnicodeString(&DeviceNode->ServiceName, ServiceName.Buffer);
/*
* Initialize the driver module if it's loaded for the first time
*/
if (Status != STATUS_IMAGE_ALREADY_LOADED)
{
Status = IopInitializeDriverModule(
DeviceNode,
ModuleObject,
&DeviceNode->ServiceName,
(Type == 2 /* SERVICE_FILE_SYSTEM_DRIVER */ ||
Type == 8 /* SERVICE_RECOGNIZER_DRIVER */),
&DriverObject);
if (!NT_SUCCESS(Status))
{
DPRINT("IopInitializeDriver() failed (Status %lx)\n", Status);
MmUnloadSystemImage(ModuleObject);
IopFreeDeviceNode(DeviceNode);
goto ReleaseCapturedString;
}
}
/* We have a driver for this DeviceNode */
DeviceNode->Flags |= DN_DRIVER_LOADED;
}
IopInitializeDevice(DeviceNode, DriverObject);
Status = IopStartDevice(DeviceNode);
ReleaseCapturedString:
ReleaseCapturedUnicodeString(&CapturedDriverServiceName,
PreviousMode);
return Status;
return LoadParams.Status;
}
/*