- NtLoadDriver APIs refactor/improvent:

* Move loading of a driver to the common (in future) routine IopLoadUnloadDriver (name taken from http://wasm.ru/forum/viewtopic.php?pid=166891).
 * Fix a bug: NtLoadDriver should always load drivers in a context of the system process, explanation here: http://www.osronline.com/showthread.cfm?link=114687
 * Reformat NtLoadDriver's code.

svn path=/trunk/; revision=29370
This commit is contained in:
Aleksey Bragin 2007-10-03 10:17:04 +00:00
parent 597b822c7e
commit 2ea98ca374
2 changed files with 116 additions and 63 deletions

View file

@ -349,6 +349,17 @@ typedef struct _OPEN_PACKET
//PIO_DRIVER_CREATE_CONTEXT DriverCreateContext; Vista only, needs ROS DDK Update //PIO_DRIVER_CREATE_CONTEXT DriverCreateContext; Vista only, needs ROS DDK Update
} OPEN_PACKET, *POPEN_PACKET; } OPEN_PACKET, *POPEN_PACKET;
//
// Parameters packet for Load/Unload work item's context
//
typedef struct _LOAD_UNLOAD_PARAMS
{
NTSTATUS Status;
PUNICODE_STRING ServiceName;
WORK_QUEUE_ITEM WorkItem;
KEVENT Event;
} LOAD_UNLOAD_PARAMS, *PLOAD_UNLOAD_PARAMS;
// //
// List of Bus Type GUIDs // List of Bus Type GUIDs
// //

View file

@ -1465,29 +1465,12 @@ IoGetDriverObjectExtension(IN PDRIVER_OBJECT DriverObject,
return DriverExtensions + 1; return DriverExtensions + 1;
} }
/* VOID NTAPI
* NtLoadDriver IopLoadUnloadDriver(PLOAD_UNLOAD_PARAMS LoadParams)
*
* Loads a device driver.
*
* Parameters
* DriverServiceName
* Name of the service to load (registry key).
*
* Return Value
* Status
*
* Status
* implemented
*/
NTSTATUS STDCALL
NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
{ {
RTL_QUERY_REGISTRY_TABLE QueryTable[3]; RTL_QUERY_REGISTRY_TABLE QueryTable[3];
UNICODE_STRING ImagePath; UNICODE_STRING ImagePath;
UNICODE_STRING ServiceName; UNICODE_STRING ServiceName;
UNICODE_STRING CapturedDriverServiceName = {0};
KPROCESSOR_MODE PreviousMode;
NTSTATUS Status; NTSTATUS Status;
ULONG Type; ULONG Type;
PDEVICE_NODE DeviceNode; PDEVICE_NODE DeviceNode;
@ -1495,50 +1478,24 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
PDRIVER_OBJECT DriverObject; PDRIVER_OBJECT DriverObject;
WCHAR *cur; WCHAR *cur;
PAGED_CODE();
PreviousMode = KeGetPreviousMode();
/*
* Check security privileges
*/
/* FIXME: Uncomment when privileges will be correctly implemented. */
#if 0
if (!SeSinglePrivilegeCheck(SeLoadDriverPrivilege, PreviousMode))
{
DPRINT("Privilege not held\n");
return STATUS_PRIVILEGE_NOT_HELD;
}
#endif
Status = ProbeAndCaptureUnicodeString(&CapturedDriverServiceName,
PreviousMode,
DriverServiceName);
if (!NT_SUCCESS(Status))
{
return Status;
}
DPRINT("NtLoadDriver('%wZ')\n", &CapturedDriverServiceName);
RtlInitUnicodeString(&ImagePath, NULL); RtlInitUnicodeString(&ImagePath, NULL);
/* /*
* Get the service name from the registry key name. * Get the service name from the registry key name.
*/ */
ASSERT(CapturedDriverServiceName.Length >= sizeof(WCHAR)); ASSERT(LoadParams->ServiceName->Length >= sizeof(WCHAR));
ServiceName = CapturedDriverServiceName; ServiceName = *LoadParams->ServiceName;
cur = CapturedDriverServiceName.Buffer + (CapturedDriverServiceName.Length / sizeof(WCHAR)) - 1; cur = LoadParams->ServiceName->Buffer +
while (CapturedDriverServiceName.Buffer != cur) (LoadParams->ServiceName->Length / sizeof(WCHAR)) - 1;
while (LoadParams->ServiceName->Buffer != cur)
{ {
if(*cur == L'\\') if(*cur == L'\\')
{ {
ServiceName.Buffer = cur + 1; ServiceName.Buffer = cur + 1;
ServiceName.Length = CapturedDriverServiceName.Length - ServiceName.Length = LoadParams->ServiceName->Length -
(USHORT)((ULONG_PTR)ServiceName.Buffer - (USHORT)((ULONG_PTR)ServiceName.Buffer -
(ULONG_PTR)CapturedDriverServiceName.Buffer); (ULONG_PTR)LoadParams->ServiceName->Buffer);
break; break;
} }
cur--; cur--;
@ -1561,13 +1518,15 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
QueryTable[1].EntryContext = &ImagePath; QueryTable[1].EntryContext = &ImagePath;
Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE, Status = RtlQueryRegistryValues(RTL_REGISTRY_ABSOLUTE,
CapturedDriverServiceName.Buffer, QueryTable, NULL, NULL); LoadParams->ServiceName->Buffer, QueryTable, NULL, NULL);
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status); DPRINT("RtlQueryRegistryValues() failed (Status %lx)\n", Status);
ExFreePool(ImagePath.Buffer); ExFreePool(ImagePath.Buffer);
goto ReleaseCapturedString; LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
} }
/* /*
@ -1579,7 +1538,9 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status); DPRINT("IopNormalizeImagePath() failed (Status %x)\n", Status);
goto ReleaseCapturedString; LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
} }
DPRINT("FullImagePath: '%wZ'\n", &ImagePath); DPRINT("FullImagePath: '%wZ'\n", &ImagePath);
@ -1595,7 +1556,9 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
if (!NT_SUCCESS(Status)) if (!NT_SUCCESS(Status))
{ {
DPRINT("IopCreateDeviceNode() failed (Status %lx)\n", Status); DPRINT("IopCreateDeviceNode() failed (Status %lx)\n", Status);
goto ReleaseCapturedString; LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
} }
/* Get existing DriverObject pointer (in case the driver has /* Get existing DriverObject pointer (in case the driver has
@ -1617,7 +1580,9 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
{ {
DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status); DPRINT("MmLoadSystemImage() failed (Status %lx)\n", Status);
IopFreeDeviceNode(DeviceNode); IopFreeDeviceNode(DeviceNode);
goto ReleaseCapturedString; LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
} }
/* /*
@ -1644,7 +1609,9 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
DPRINT("IopInitializeDriver() failed (Status %lx)\n", Status); DPRINT("IopInitializeDriver() failed (Status %lx)\n", Status);
MmUnloadSystemImage(ModuleObject); MmUnloadSystemImage(ModuleObject);
IopFreeDeviceNode(DeviceNode); IopFreeDeviceNode(DeviceNode);
goto ReleaseCapturedString; LoadParams->Status = Status;
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
return;
} }
} }
@ -1653,13 +1620,88 @@ NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
} }
IopInitializeDevice(DeviceNode, DriverObject); IopInitializeDevice(DeviceNode, DriverObject);
Status = IopStartDevice(DeviceNode); LoadParams->Status = IopStartDevice(DeviceNode);
(VOID)KeSetEvent(&LoadParams->Event, 0, FALSE);
}
ReleaseCapturedString: /*
ReleaseCapturedUnicodeString(&CapturedDriverServiceName, * NtLoadDriver
PreviousMode); *
* Loads a device driver.
*
* Parameters
* DriverServiceName
* Name of the service to load (registry key).
*
* Return Value
* Status
*
* Status
* implemented
*/
NTSTATUS STDCALL
NtLoadDriver(IN PUNICODE_STRING DriverServiceName)
{
UNICODE_STRING CapturedDriverServiceName = {0};
KPROCESSOR_MODE PreviousMode;
LOAD_UNLOAD_PARAMS LoadParams;
NTSTATUS Status;
return Status; PAGED_CODE();
PreviousMode = KeGetPreviousMode();
/*
* Check security privileges
*/
/* FIXME: Uncomment when privileges will be correctly implemented. */
#if 0
if (!SeSinglePrivilegeCheck(SeLoadDriverPrivilege, PreviousMode))
{
DPRINT("Privilege not held\n");
return STATUS_PRIVILEGE_NOT_HELD;
}
#endif
Status = ProbeAndCaptureUnicodeString(&CapturedDriverServiceName,
PreviousMode,
DriverServiceName);
if (!NT_SUCCESS(Status))
{
return Status;
}
DPRINT("NtLoadDriver('%wZ')\n", &CapturedDriverServiceName);
LoadParams.ServiceName = &CapturedDriverServiceName;
KeInitializeEvent(&LoadParams.Event, NotificationEvent, FALSE);
/* Call the load/unload routine, depending on current process */
if (PsGetCurrentProcess() == PsInitialSystemProcess)
{
/* Just call right away */
IopLoadUnloadDriver(&LoadParams);
}
else
{
/* Load/Unload must be called from system process */
ExInitializeWorkItem(&LoadParams.WorkItem,
(PWORKER_THREAD_ROUTINE)IopLoadUnloadDriver,
(PVOID)&LoadParams);
/* Queue it */
ExQueueWorkItem(&LoadParams.WorkItem, DelayedWorkQueue);
/* And wait when it completes */
KeWaitForSingleObject(&LoadParams.Event, UserRequest, KernelMode,
FALSE, NULL);
}
ReleaseCapturedUnicodeString(&CapturedDriverServiceName,
PreviousMode);
return LoadParams.Status;
} }
/* /*