diff --git a/kmtests/CMakeLists.txt b/kmtests/CMakeLists.txt index fbb3aec1a31..78ba94eb437 100644 --- a/kmtests/CMakeLists.txt +++ b/kmtests/CMakeLists.txt @@ -4,7 +4,7 @@ include_directories( # # subdirectories containing special-purpose drivers # -#add_subdirectory(something) +add_subdirectory(Example) # # kmtest_drv.sys driver diff --git a/kmtests/directory.rbuild b/kmtests/directory.rbuild index 5d6c89d4c4b..3c42f836e24 100644 --- a/kmtests/directory.rbuild +++ b/kmtests/directory.rbuild @@ -1,9 +1,9 @@ - + + + diff --git a/kmtests/example/CMakeLists.txt b/kmtests/example/CMakeLists.txt new file mode 100644 index 00000000000..3ac0f80e170 --- /dev/null +++ b/kmtests/example/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories( + ../include) + +list(APPEND EXAMPLE_DRV_SOURCE + ../kmtest_drv/kmtest_standalone.c + Example_drv.c) + +add_library(example_drv SHARED ${EXAMPLE_DRV_SOURCE}) + +set_module_type(example_drv kernelmodedriver) +target_link_libraries(example_drv kmtest_printf ${PSEH_LIB}) +add_importlibs(example_drv ntoskrnl hal) +set_property(TARGET example_drv PROPERTY COMPILE_DEFINITIONS KMT_STANDALONE_DRIVER) + +add_cd_file(TARGET example_drv DESTINATION reactos/system32/drivers FOR all) diff --git a/kmtests/example/Example.h b/kmtests/example/Example.h new file mode 100644 index 00000000000..7c7bfec9494 --- /dev/null +++ b/kmtests/example/Example.h @@ -0,0 +1,21 @@ +/* + * PROJECT: ReactOS kernel-mode tests + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Kernel-Mode Test Suite Example Test declarations + * PROGRAMMER: Thomas Faber + */ + +#ifndef _KMTEST_EXAMPLE_H_ +#define _KMTEST_EXAMPLE_H_ + +typedef struct +{ + int a; + char b[8]; +} MY_STRUCT, *PMY_STRUCT; + +#define IOCTL_NOTIFY 1 +#define IOCTL_SEND_STRING 2 +#define IOCTL_SEND_MYSTRUCT 3 + +#endif /* !defined _KMTEST_EXAMPLE_H_ */ diff --git a/kmtests/example/Example_drv.c b/kmtests/example/Example_drv.c new file mode 100644 index 00000000000..dd7504e01f3 --- /dev/null +++ b/kmtests/example/Example_drv.c @@ -0,0 +1,246 @@ +/* + * PROJECT: ReactOS kernel-mode tests + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Kernel-Mode Test Suite Example Test Driver + * PROGRAMMER: Thomas Faber + */ + +#include + +#include + +//#define NDEBUG +#include + +#include "Example.h" + +/* prototypes */ +static KMT_MESSAGE_HANDLER TestMessageHandler; +static KMT_IRP_HANDLER TestIrpHandler; + +/* globals */ +static PDRIVER_OBJECT TestDriverObject; + +/** + * @name TestEntry + * + * Test entry point. + * This is called by DriverEntry as early as possible, but with ResultBuffer + * initialized, so that test macros work correctly + * + * @param DriverObject + * Driver Object. + * This is guaranteed not to have been touched by DriverEntry before + * the call to TestEntry + * @param RegistryPath + * Driver Registry Path + * This is guaranteed not to have been touched by DriverEntry before + * the call to TestEntry + * @param DeviceName + * Pointer to receive a test-specific name for the device to create + * @param Flags + * Pointer to a flags variable instructing DriverEntry how to proceed. + * See the KMT_TESTENTRY_FLAGS enumeration for possible values + * Initialized to zero on entry + * + * @return Status. + * DriverEntry will fail if this is a failure status + */ +NTSTATUS +TestEntry( + IN PDRIVER_OBJECT DriverObject, + IN PCUNICODE_STRING RegistryPath, + OUT PCWSTR *DeviceName, + OUT INT *Flags) +{ + NTSTATUS Status = STATUS_SUCCESS; + + PAGED_CODE(); + + UNREFERENCED_PARAMETER(RegistryPath); + + DPRINT("Entry!\n"); + + ok_irql(PASSIVE_LEVEL); + TestDriverObject = DriverObject; + + *DeviceName = L"Example"; + + trace("Hi, this is the example driver\n"); + + KmtRegisterIrpHandler(IRP_MJ_CREATE, NULL, TestIrpHandler); + KmtRegisterIrpHandler(IRP_MJ_CLOSE, NULL, TestIrpHandler); + KmtRegisterMessageHandler(0, NULL, TestMessageHandler); + + return Status; +} + +/** + * @name TestUnload + * + * Test unload routine. + * This is called by the driver's Unload routine as early as possible, with + * ResultBuffer and the test device object still valid, so that test macros + * work correctly + * + * @param DriverObject + * Driver Object. + * This is guaranteed not to have been touched by Unload before the call + * to TestEntry + * + * @return Status + */ +VOID +TestUnload( + IN PDRIVER_OBJECT DriverObject) +{ + PAGED_CODE(); + + DPRINT("Unload!\n"); + + ok_irql(PASSIVE_LEVEL); + ok_eq_pointer(DriverObject, TestDriverObject); + + trace("Unloading example driver\n"); +} + +/** + * @name TestMessageHandler + * + * Test message handler routine + * + * @param DeviceObject + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * @param Irp + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler, except for passing it to + * IoGetCurrentStackLocation + * @param IoStackLocation + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * + * @return Status + */ +static +NTSTATUS +TestMessageHandler( + IN PDEVICE_OBJECT DeviceObject, + IN ULONG ControlCode, + IN PVOID Buffer OPTIONAL, + IN SIZE_T InLength, + IN OUT PSIZE_T OutLength) +{ + NTSTATUS Status = STATUS_SUCCESS; + + switch (ControlCode) + { + case IOCTL_NOTIFY: + { + static int TimesReceived = 0; + + ++TimesReceived; + ok(TimesReceived == 1, "Received control code 1 %d times\n", TimesReceived); + ok_eq_pointer(Buffer, NULL); + ok_eq_ulong((ULONG)InLength, 0LU); + ok_eq_ulong((ULONG)*OutLength, 0LU); + break; + } + case IOCTL_SEND_STRING: + { + static int TimesReceived = 0; + ANSI_STRING ExpectedString = RTL_CONSTANT_STRING("yay"); + ANSI_STRING ReceivedString; + + ++TimesReceived; + ok(TimesReceived == 1, "Received control code 2 %d times\n", TimesReceived); + ok(Buffer != NULL, "Buffer is NULL\n"); + ok_eq_ulong((ULONG)InLength, (ULONG)ExpectedString.Length); + ok_eq_ulong((ULONG)*OutLength, 0LU); + ReceivedString.MaximumLength = ReceivedString.Length = (USHORT)InLength; + ReceivedString.Buffer = Buffer; + ok(RtlCompareString(&ExpectedString, &ReceivedString, FALSE) == 0, "Received string: %Z\n", &ReceivedString); + break; + } + case IOCTL_SEND_MYSTRUCT: + { + static int TimesReceived = 0; + MY_STRUCT ExpectedStruct = { 123, ":D" }; + MY_STRUCT ResultStruct = { 456, "!!!" }; + + ++TimesReceived; + ok(TimesReceived == 1, "Received control code 3 %d times\n", TimesReceived); + ok(Buffer != NULL, "Buffer is NULL\n"); + ok_eq_ulong((ULONG)InLength, (ULONG)sizeof ExpectedStruct); + ok_eq_ulong((ULONG)*OutLength, 2LU * sizeof ExpectedStruct); + if (!skip(Buffer && InLength >= sizeof ExpectedStruct, "Cannot read from buffer!\n")) + ok(RtlCompareMemory(&ExpectedStruct, Buffer, sizeof ExpectedStruct) == sizeof ExpectedStruct, "Buffer does not contain expected values\n"); + + if (!skip(Buffer && *OutLength >= 2 * sizeof ExpectedStruct, "Cannot write to buffer!\n")) + { + RtlCopyMemory((PCHAR)Buffer + sizeof ExpectedStruct, &ResultStruct, sizeof ResultStruct); + *OutLength = 2 * sizeof ExpectedStruct; + } + break; + } + default: + ok(0, "Got an unknown message! DeviceObject=%p, ControlCode=%lu, Buffer=%p, In=%lu, Out=%lu bytes\n", + DeviceObject, ControlCode, Buffer, InLength, *OutLength); + break; + } + + return Status; +} + +/** + * @name TestIrpHandler + * + * Test IRP handler routine + * + * @param DeviceObject + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * @param Irp + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler, except for passing it to + * IoGetCurrentStackLocation + * @param IoStackLocation + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * + * @return Status + */ +static +NTSTATUS +TestIrpHandler( + IN PDEVICE_OBJECT DeviceObject, + IN PIRP Irp, + IN PIO_STACK_LOCATION IoStackLocation) +{ + NTSTATUS Status = STATUS_SUCCESS; + + DPRINT("IRP!\n"); + + ok_irql(PASSIVE_LEVEL); + ok_eq_pointer(DeviceObject->DriverObject, TestDriverObject); + + if (IoStackLocation->MajorFunction == IRP_MJ_CREATE) + trace("Got IRP_MJ_CREATE!\n"); + else if (IoStackLocation->MajorFunction == IRP_MJ_CLOSE) + trace("Got IRP_MJ_CLOSE!\n"); + else + trace("Got an IRP!\n"); + + Irp->IoStatus.Status = Status; + Irp->IoStatus.Information = 0; + + IoCompleteRequest(Irp, IO_NO_INCREMENT); + + return Status; +} diff --git a/kmtests/example/Example_user.c b/kmtests/example/Example_user.c index e5794eab257..6f3d7bfe229 100644 --- a/kmtests/example/Example_user.c +++ b/kmtests/example/Example_user.c @@ -10,10 +10,14 @@ #include #include +#include "Example.h" + START_TEST(Example) { /* do some user-mode stuff */ SYSTEM_INFO SystemInfo; + MY_STRUCT MyStruct[2] = { { 123, ":D" }, { 0 } }; + DWORD Length = sizeof MyStruct; trace("Message from user-mode\n"); @@ -23,4 +27,21 @@ START_TEST(Example) /* now run the kernel-mode part (see Example.c). * If no user-mode part exists, this is what's done automatically */ KmtRunKernelTest("Example"); + + /* now start the special-purpose driver */ + KmtLoadDriver(L"Example", FALSE); + trace("After Entry\n"); + KmtOpenDriver(); + trace("After Create\n"); + + ok(KmtSendToDriver(IOCTL_NOTIFY) == ERROR_SUCCESS, "\n"); + ok(KmtSendStringToDriver(IOCTL_SEND_STRING, "yay") == ERROR_SUCCESS, "\n"); + ok(KmtSendBufferToDriver(IOCTL_SEND_MYSTRUCT, MyStruct, sizeof MyStruct[0], &Length) == ERROR_SUCCESS, "\n"); + ok_eq_int(MyStruct[1].a, 456); + ok_eq_str(MyStruct[1].b, "!!!"); + + KmtCloseDriver(); + trace("After Close\n"); + KmtUnloadDriver(); + trace("After Unload\n"); } diff --git a/kmtests/example/example_drv.rbuild b/kmtests/example/example_drv.rbuild new file mode 100644 index 00000000000..c6041e8aee2 --- /dev/null +++ b/kmtests/example/example_drv.rbuild @@ -0,0 +1,14 @@ + + include + ntoskrnl + hal + pseh + kmtest_printf + + Example_drv.c + + + kmtest_standalone.c + + + diff --git a/kmtests/include/kmt_test.h b/kmtests/include/kmt_test.h index 0e57fc2e259..dd4311c228f 100644 --- a/kmtests/include/kmt_test.h +++ b/kmtests/include/kmt_test.h @@ -38,6 +38,40 @@ typedef struct CHAR LogBuffer[ANYSIZE_ARRAY]; } KMT_RESULTBUFFER, *PKMT_RESULTBUFFER; +#ifdef KMT_STANDALONE_DRIVER +#define KMT_KERNEL_MODE + +typedef NTSTATUS (KMT_IRP_HANDLER)( + IN PDEVICE_OBJECT DeviceObject, + IN PIRP Irp, + IN PIO_STACK_LOCATION IoStackLocation); +typedef KMT_IRP_HANDLER *PKMT_IRP_HANDLER; + +NTSTATUS KmtRegisterIrpHandler(IN UCHAR MajorFunction, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_IRP_HANDLER IrpHandler); +NTSTATUS KmtUnregisterIrpHandler(IN UCHAR MajorFunction, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_IRP_HANDLER IrpHandler); + +typedef NTSTATUS (KMT_MESSAGE_HANDLER)( + IN PDEVICE_OBJECT DeviceObject, + IN ULONG ControlCode, + IN PVOID Buffer OPTIONAL, + IN SIZE_T InLength, + IN OUT PSIZE_T OutLength); +typedef KMT_MESSAGE_HANDLER *PKMT_MESSAGE_HANDLER; + +NTSTATUS KmtRegisterMessageHandler(IN ULONG ControlCode OPTIONAL, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_MESSAGE_HANDLER MessageHandler); +NTSTATUS KmtUnregisterMessageHandler(IN ULONG ControlCode OPTIONAL, IN PDEVICE_OBJECT DeviceObject OPTIONAL, IN PKMT_MESSAGE_HANDLER MessageHandler); + +typedef enum +{ + TESTENTRY_NO_CREATE_DEVICE = 1, + TESTENTRY_NO_REGISTER_DISPATCH = 2, + TESTENTRY_NO_REGISTER_UNLOAD = 4, +} KMT_TESTENTRY_FLAGS; + +NTSTATUS TestEntry(IN PDRIVER_OBJECT DriverObject, IN PCUNICODE_STRING RegistryPath, OUT PCWSTR *DeviceName, OUT INT *Flags); +VOID TestUnload(IN PDRIVER_OBJECT DriverObject); +#endif /* defined KMT_STANDALONE_DRIVER */ + #ifdef KMT_KERNEL_MODE /* Device Extension layout */ typedef struct @@ -55,7 +89,7 @@ VOID KmtCloseDriver(VOID); DWORD KmtSendToDriver(IN DWORD ControlCode); DWORD KmtSendStringToDriver(IN DWORD ControlCode, IN PCSTR String); -DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer, IN OUT PDWORD Length); +DWORD KmtSendBufferToDriver(IN DWORD ControlCode, IN OUT PVOID Buffer OPTIONAL, IN DWORD InLength, IN OUT PDWORD OutLength); #endif /* defined KMT_USER_MODE */ extern PKMT_RESULTBUFFER ResultBuffer; @@ -105,7 +139,7 @@ BOOLEAN KmtSkip(INT Condition, PCSTR FileAndLine, PCSTR Format, ...) #define ok_eq_wstr(value, expected) ok(!wcscmp(value, expected), #value " = \"%ls\", expected \"%ls\"\n", value, expected) #define KMT_MAKE_CODE(ControlCode) CTL_CODE(FILE_DEVICE_UNKNOWN, \ - 0xA00 + (ControlCode), \ + 0xC00 + (ControlCode), \ METHOD_BUFFERED, \ FILE_ANY_ACCESS) diff --git a/kmtests/kmtest/support.c b/kmtests/kmtest/support.c index 213a1e28c90..bf55ed7e135 100644 --- a/kmtests/kmtest/support.c +++ b/kmtests/kmtest/support.c @@ -192,6 +192,8 @@ KmtSendToDriver( { DWORD BytesRead; + assert(ControlCode < 0x400); + if (!DeviceIoControl(TestDeviceHandle, KMT_MAKE_CODE(ControlCode), NULL, 0, NULL, 0, &BytesRead, NULL)) return GetLastError(); @@ -215,6 +217,8 @@ KmtSendStringToDriver( { DWORD BytesRead; + assert(ControlCode < 0x400); + if (!DeviceIoControl(TestDeviceHandle, KMT_MAKE_CODE(ControlCode), (PVOID)String, strlen(String), NULL, 0, &BytesRead, NULL)) return GetLastError(); @@ -226,19 +230,23 @@ KmtSendStringToDriver( * * @param ControlCode * @param Buffer - * @param Length + * @param InLength + * @param OutLength * * @return Win32 error code as returned by DeviceIoControl */ DWORD KmtSendBufferToDriver( IN DWORD ControlCode, - IN OUT PVOID Buffer, - IN OUT PDWORD Length) + IN OUT PVOID Buffer OPTIONAL, + IN DWORD InLength, + IN OUT PDWORD OutLength) { - assert(Length); + assert(OutLength); + assert(Buffer || (!InLength && !*OutLength)); + assert(ControlCode < 0x400); - if (!DeviceIoControl(TestDeviceHandle, KMT_MAKE_CODE(ControlCode), Buffer, *Length, NULL, 0, Length, NULL)) + if (!DeviceIoControl(TestDeviceHandle, KMT_MAKE_CODE(ControlCode), Buffer, InLength, Buffer, *OutLength, OutLength, NULL)) return GetLastError(); return ERROR_SUCCESS; diff --git a/kmtests/kmtest_drv.rbuild b/kmtests/kmtest_drv.rbuild index a243df6735d..c16db22d603 100644 --- a/kmtests/kmtest_drv.rbuild +++ b/kmtests/kmtest_drv.rbuild @@ -1,7 +1,6 @@ include ntoskrnl - ntdll hal pseh kmtest_printf diff --git a/kmtests/kmtest_drv/kmtest_standalone.c b/kmtests/kmtest_drv/kmtest_standalone.c new file mode 100644 index 00000000000..b3e856b02a9 --- /dev/null +++ b/kmtests/kmtest_drv/kmtest_standalone.c @@ -0,0 +1,513 @@ +/* + * PROJECT: ReactOS kernel-mode tests + * LICENSE: GPLv2+ - See COPYING in the top level directory + * PURPOSE: Kernel-Mode Test Suite Example Test Driver + * PROGRAMMER: Thomas Faber + */ + +#include + +#define KMT_DEFINE_TEST_FUNCTIONS +#include + +//#define NDEBUG +#include + +#include + +/* types */ +typedef struct +{ + UCHAR MajorFunction; + PDEVICE_OBJECT DeviceObject; + PKMT_IRP_HANDLER IrpHandler; +} KMT_IRP_HANDLER_ENTRY, *PKMT_IRP_HANDLER_ENTRY; + +typedef struct +{ + ULONG ControlCode; + PDEVICE_OBJECT DeviceObject; + PKMT_MESSAGE_HANDLER MessageHandler; +} KMT_MESSAGE_HANDLER_ENTRY, *PKMT_MESSAGE_HANDLER_ENTRY; + +/* Prototypes */ +DRIVER_INITIALIZE DriverEntry; +static DRIVER_UNLOAD DriverUnload; +static DRIVER_DISPATCH DriverDispatch; +static KMT_IRP_HANDLER DeviceControlHandler; + +/* Globals */ +static PDEVICE_OBJECT TestDeviceObject; +static PDEVICE_OBJECT KmtestDeviceObject; + +#define KMT_MAX_IRP_HANDLERS 256 +static KMT_IRP_HANDLER_ENTRY IrpHandlers[KMT_MAX_IRP_HANDLERS] = { { 0 } }; +#define KMT_MAX_MESSAGE_HANDLERS 256 +static KMT_MESSAGE_HANDLER_ENTRY MessageHandlers[KMT_MAX_MESSAGE_HANDLERS] = { { 0 } }; + +static const char *IrpMajorFunctionNames[] = +{ + "Create", + "CreateNamedPipe", + "Close", + "Read", + "Write", + "QueryInformation", + "SetInformation", + "QueryEa", + "SetEa", + "FlushBuffers", + "QueryVolumeInformation", + "SetVolumeInformation", + "DirectoryControl", + "FileSystemControl", + "DeviceControl", + "InternalDeviceControl/Scsi", + "Shutdown", + "LockControl", + "Cleanup", + "CreateMailslot", + "QuerySecurity", + "SetSecurity", + "Power", + "SystemControl", + "DeviceChange", + "QueryQuota", + "SetQuota", + "Pnp/PnpPower" +}; + +/** + * @name DriverEntry + * + * Driver entry point. + * + * @param DriverObject + * Driver Object + * @param RegistryPath + * Driver Registry Path + * + * @return Status + */ +NTSTATUS +NTAPI +DriverEntry( + IN PDRIVER_OBJECT DriverObject, + IN PUNICODE_STRING RegistryPath) +{ + NTSTATUS Status = STATUS_SUCCESS; + WCHAR DeviceNameBuffer[128] = L"\\Device\\Kmtest-"; + UNICODE_STRING KmtestDeviceName; + PFILE_OBJECT KmtestFileObject; + PKMT_DEVICE_EXTENSION KmtestDeviceExtension; + UNICODE_STRING DeviceName; + PCWSTR DeviceNameSuffix; + INT Flags = 0; + int i; + + PAGED_CODE(); + + DPRINT("DriverEntry\n"); + + /* get the Kmtest device, so that we get a ResultBuffer pointer */ + RtlInitUnicodeString(&KmtestDeviceName, KMTEST_DEVICE_DRIVER_PATH); + Status = IoGetDeviceObjectPointer(&KmtestDeviceName, FILE_ALL_ACCESS, &KmtestFileObject, &KmtestDeviceObject); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to get Kmtest device object pointer\n"); + goto cleanup; + } + + Status = ObReferenceObjectByPointer(KmtestDeviceObject, FILE_ALL_ACCESS, NULL, KernelMode); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Failed to reference Kmtest device object\n"); + goto cleanup; + } + + ObDereferenceObject(KmtestFileObject); + KmtestFileObject = NULL; + KmtestDeviceExtension = KmtestDeviceObject->DeviceExtension; + ResultBuffer = KmtestDeviceExtension->ResultBuffer; + DPRINT("KmtestDeviceObject: %p\n", (PVOID)KmtestDeviceObject); + DPRINT("KmtestDeviceExtension: %p\n", (PVOID)KmtestDeviceExtension); + DPRINT("Setting ResultBuffer: %p\n", (PVOID)ResultBuffer); + + /* call TestEntry */ + RtlInitUnicodeString(&DeviceName, DeviceNameBuffer); + DeviceName.MaximumLength = sizeof DeviceNameBuffer; + TestEntry(DriverObject, RegistryPath, &DeviceNameSuffix, &Flags); + RtlAppendUnicodeToString(&DeviceName, DeviceNameSuffix); + + /* create test device */ + if (!(Flags & TESTENTRY_NO_CREATE_DEVICE)) + { + Status = IoCreateDevice(DriverObject, 0, &DeviceName, + FILE_DEVICE_UNKNOWN, + FILE_DEVICE_SECURE_OPEN | FILE_READ_ONLY_DEVICE, + TRUE, &TestDeviceObject); + + if (!NT_SUCCESS(Status)) + { + DPRINT1("Could not create device object %wZ\n", &DeviceName); + goto cleanup; + } + + DPRINT("DriverEntry. Created DeviceObject %p\n", + TestDeviceObject); + } + + /* initialize dispatch functions */ + if (!(Flags & TESTENTRY_NO_REGISTER_UNLOAD)) + DriverObject->DriverUnload = DriverUnload; + if (!(Flags & TESTENTRY_NO_REGISTER_DISPATCH)) + for (i = 0; i <= IRP_MJ_MAXIMUM_FUNCTION; ++i) + DriverObject->MajorFunction[i] = DriverDispatch; + +cleanup: + if (TestDeviceObject && !NT_SUCCESS(Status)) + { + IoDeleteDevice(TestDeviceObject); + TestDeviceObject = NULL; + } + + if (KmtestDeviceObject && !NT_SUCCESS(Status)) + { + ObDereferenceObject(KmtestDeviceObject); + KmtestDeviceObject = NULL; + if (KmtestFileObject) + ObDereferenceObject(KmtestFileObject); + } + + return Status; +} + +/** + * @name DriverUnload + * + * Driver cleanup funtion. + * + * @param DriverObject + * Driver Object + */ +static +VOID +NTAPI +DriverUnload( + IN PDRIVER_OBJECT DriverObject) +{ + PAGED_CODE(); + + UNREFERENCED_PARAMETER(DriverObject); + + DPRINT("DriverUnload\n"); + + TestUnload(DriverObject); + + if (TestDeviceObject) + IoDeleteDevice(TestDeviceObject); + + if (KmtestDeviceObject) + ObDereferenceObject(KmtestDeviceObject); +} + +/** + * @name KmtRegisterIrpHandler + * + * Register a handler with the IRP Dispatcher. + * If multiple registered handlers match an IRP, it is unspecified which of + * them is called on IRP reception + * + * @param MajorFunction + * IRP major function code to be handled + * @param DeviceObject + * Device Object to handle IRPs for. + * Can be NULL to indicate any device object + * @param IrpHandler + * Handler function to register. + * + * @return Status + */ +NTSTATUS +KmtRegisterIrpHandler( + IN UCHAR MajorFunction, + IN PDEVICE_OBJECT DeviceObject OPTIONAL, + IN PKMT_IRP_HANDLER IrpHandler) +{ + NTSTATUS Status = STATUS_SUCCESS; + int i; + + if (MajorFunction > IRP_MJ_MAXIMUM_FUNCTION) + { + Status = STATUS_INVALID_PARAMETER_1; + goto cleanup; + } + + if (IrpHandler == NULL) + { + Status = STATUS_INVALID_PARAMETER_3; + goto cleanup; + } + + for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i) + if (IrpHandlers[i].IrpHandler == NULL) + { + IrpHandlers[i].MajorFunction = MajorFunction; + IrpHandlers[i].DeviceObject = DeviceObject; + IrpHandlers[i].IrpHandler = IrpHandler; + goto cleanup; + } + + Status = STATUS_ALLOTTED_SPACE_EXCEEDED; + +cleanup: + return Status; +} + +/** + * @name KmtUnregisterIrpHandler + * + * Unregister a handler with the IRP Dispatcher. + * Parameters must be specified exactly as in the call to + * KmtRegisterIrpHandler. Only the first matching entry will be removed + * if multiple exist + * + * @param MajorFunction + * IRP major function code of the handler to be removed + * @param DeviceObject + * Device Object to of the handler to be removed + * @param IrpHandler + * Handler function of the handler to be removed + * + * @return Status + */ +NTSTATUS +KmtUnregisterIrpHandler( + IN UCHAR MajorFunction, + IN PDEVICE_OBJECT DeviceObject OPTIONAL, + IN PKMT_IRP_HANDLER IrpHandler) +{ + NTSTATUS Status = STATUS_SUCCESS; + int i; + + for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i) + if (IrpHandlers[i].MajorFunction == MajorFunction && + IrpHandlers[i].DeviceObject == DeviceObject && + IrpHandlers[i].IrpHandler == IrpHandler) + { + IrpHandlers[i].IrpHandler = NULL; + goto cleanup; + } + + Status = STATUS_NOT_FOUND; + +cleanup: + return Status; +} + +/** + * @name DriverDispatch + * + * Driver Dispatch function + * + * @param DeviceObject + * Device Object + * @param Irp + * I/O request packet + * + * @return Status + */ +static +NTSTATUS +NTAPI +DriverDispatch( + IN PDEVICE_OBJECT DeviceObject, + IN PIRP Irp) +{ + NTSTATUS Status = STATUS_SUCCESS; + PIO_STACK_LOCATION IoStackLocation; + int i; + + PAGED_CODE(); + + IoStackLocation = IoGetCurrentIrpStackLocation(Irp); + + DPRINT("DriverDispatch: Function=%s, Device=%p\n", + IrpMajorFunctionNames[IoStackLocation->MajorFunction], + DeviceObject); + + for (i = 0; i < sizeof IrpHandlers / sizeof IrpHandlers[0]; ++i) + { + if (IrpHandlers[i].MajorFunction == IoStackLocation->MajorFunction && + (IrpHandlers[i].DeviceObject == NULL || IrpHandlers[i].DeviceObject == DeviceObject) && + IrpHandlers[i].IrpHandler != NULL) + return IrpHandlers[i].IrpHandler(DeviceObject, Irp, IoStackLocation); + } + + /* default handler for DeviceControl */ + if (IoStackLocation->MajorFunction == IRP_MJ_DEVICE_CONTROL || + IoStackLocation->MajorFunction == IRP_MJ_INTERNAL_DEVICE_CONTROL) + return DeviceControlHandler(DeviceObject, Irp, IoStackLocation); + + /* default handler */ + Irp->IoStatus.Status = Status; + Irp->IoStatus.Information = 0; + + IoCompleteRequest(Irp, IO_NO_INCREMENT); + + return Status; +} + +/** + * @name KmtRegisterMessageHandler + * + * Register a handler with the DeviceControl Dispatcher. + * If multiple registered handlers match a message, it is unspecified which of + * them is called on message reception. + * NOTE: message handlers registered with this function will not be called + * if a custom IRP handler matching the corresponding IRP is installed! + * + * @param ControlCode + * Control code to be handled, as passed by the application. + * Can be 0 to indicate any control code + * @param DeviceObject + * Device Object to handle IRPs for. + * Can be NULL to indicate any device object + * @param MessageHandler + * Handler function to register. + * + * @return Status + */ +NTSTATUS +KmtRegisterMessageHandler( + IN ULONG ControlCode OPTIONAL, + IN PDEVICE_OBJECT DeviceObject OPTIONAL, + IN PKMT_MESSAGE_HANDLER MessageHandler) +{ + NTSTATUS Status = STATUS_SUCCESS; + int i; + + if (ControlCode >= 0x400) + { + Status = STATUS_INVALID_PARAMETER_1; + goto cleanup; + } + + if (MessageHandler == NULL) + { + Status = STATUS_INVALID_PARAMETER_2; + goto cleanup; + } + + for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i) + if (MessageHandlers[i].MessageHandler == NULL) + { + MessageHandlers[i].ControlCode = ControlCode; + MessageHandlers[i].DeviceObject = DeviceObject; + MessageHandlers[i].MessageHandler = MessageHandler; + goto cleanup; + } + + Status = STATUS_ALLOTTED_SPACE_EXCEEDED; + +cleanup: + return Status; +} + +/** + * @name KmtUnregisterMessageHandler + * + * Unregister a handler with the DeviceControl Dispatcher. + * Parameters must be specified exactly as in the call to + * KmtRegisterMessageHandler. Only the first matching entry will be removed + * if multiple exist + * + * @param ControlCode + * Control code of the handler to be removed + * @param DeviceObject + * Device Object to of the handler to be removed + * @param MessageHandler + * Handler function of the handler to be removed + * + * @return Status + */ +NTSTATUS +KmtUnregisterMessageHandler( + IN ULONG ControlCode OPTIONAL, + IN PDEVICE_OBJECT DeviceObject OPTIONAL, + IN PKMT_MESSAGE_HANDLER MessageHandler) +{ + NTSTATUS Status = STATUS_SUCCESS; + int i; + + for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i) + if (MessageHandlers[i].ControlCode == ControlCode && + MessageHandlers[i].DeviceObject == DeviceObject && + MessageHandlers[i].MessageHandler == MessageHandler) + { + MessageHandlers[i].MessageHandler = NULL; + goto cleanup; + } + + Status = STATUS_NOT_FOUND; + +cleanup: + return Status; +} + +/** + * @name DeviceControlHandler + * + * Default IRP_MJ_DEVICE_CONTROL/IRP_MJ_INTERNAL_DEVICE_CONTROL handler + * + * @param DeviceObject + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * @param Irp + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler, except for passing it to + * IoGetCurrentStackLocation + * @param IoStackLocation + * Device Object. + * This is guaranteed not to have been touched by the dispatch function + * before the call to the IRP handler + * + * @return Status + */ +static +NTSTATUS +DeviceControlHandler( + IN PDEVICE_OBJECT DeviceObject, + IN PIRP Irp, + IN PIO_STACK_LOCATION IoStackLocation) +{ + NTSTATUS Status = STATUS_SUCCESS; + ULONG ControlCode = (IoStackLocation->Parameters.DeviceIoControl.IoControlCode & 0x00000FFC) >> 2; + ULONG OutLength = IoStackLocation->Parameters.DeviceIoControl.OutputBufferLength; + int i; + + for (i = 0; i < sizeof MessageHandlers / sizeof MessageHandlers[0]; ++i) + { + if ((MessageHandlers[i].ControlCode == 0 || + MessageHandlers[i].ControlCode == ControlCode) && + (MessageHandlers[i].DeviceObject == NULL || MessageHandlers[i].DeviceObject == DeviceObject) && + MessageHandlers[i].MessageHandler != NULL) + { + Status = MessageHandlers[i].MessageHandler(DeviceObject, ControlCode, Irp->AssociatedIrp.SystemBuffer, + IoStackLocation->Parameters.DeviceIoControl.InputBufferLength, + &OutLength); + break; + } + } + + Irp->IoStatus.Status = Status; + Irp->IoStatus.Information = OutLength; + + IoCompleteRequest(Irp, IO_NO_INCREMENT); + + return Status; +}