diff --git a/rostests/kmtests/CMakeLists.txt b/rostests/kmtests/CMakeLists.txt index 556a7b4fd9e..6cd30f6aa50 100644 --- a/rostests/kmtests/CMakeLists.txt +++ b/rostests/kmtests/CMakeLists.txt @@ -117,7 +117,7 @@ list(APPEND KMTEST_SOURCE add_executable(kmtest ${KMTEST_SOURCE}) set_module_type(kmtest win32cui) target_link_libraries(kmtest ${PSEH_LIB}) -add_importlibs(kmtest advapi32 msvcrt kernel32 ntdll) +add_importlibs(kmtest advapi32 ws2_32 msvcrt kernel32 ntdll) add_target_compile_definitions(kmtest KMT_USER_MODE) #add_pch(kmtest include/kmt_test.h) set_target_properties(kmtest PROPERTIES OUTPUT_NAME "kmtest_") diff --git a/rostests/kmtests/kmtest/testlist.c b/rostests/kmtests/kmtest/testlist.c index 22c7d8c43d4..9688ec70e03 100644 --- a/rostests/kmtests/kmtest/testlist.c +++ b/rostests/kmtests/kmtest/testlist.c @@ -19,6 +19,7 @@ KMT_TESTFUNC Test_RtlSplayTree; KMT_TESTFUNC Test_RtlUnicodeString; KMT_TESTFUNC Test_TcpIpIoctl; KMT_TESTFUNC Test_TcpIpTdi; +KMT_TESTFUNC Test_TcpIpConnect; /* tests with a leading '-' will not be listed */ const KMT_TEST TestList[] = @@ -34,5 +35,6 @@ const KMT_TEST TestList[] = { "RtlSplayTree", Test_RtlSplayTree }, { "RtlUnicodeString", Test_RtlUnicodeString }, { "TcpIpTdi", Test_TcpIpTdi }, + { "TcpIpConnect", Test_TcpIpConnect }, { NULL, NULL }, }; diff --git a/rostests/kmtests/tcpip/CMakeLists.txt b/rostests/kmtests/tcpip/CMakeLists.txt index 2995dcc9f09..33614812352 100644 --- a/rostests/kmtests/tcpip/CMakeLists.txt +++ b/rostests/kmtests/tcpip/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(../include) list(APPEND TCPIP_TEST_DRV_SOURCE ../kmtest_drv/kmtest_standalone.c + connect.c tdi.c TcpIp_drv.c) diff --git a/rostests/kmtests/tcpip/TcpIp_drv.c b/rostests/kmtests/tcpip/TcpIp_drv.c index 6877e52751c..6635219b1b2 100644 --- a/rostests/kmtests/tcpip/TcpIp_drv.c +++ b/rostests/kmtests/tcpip/TcpIp_drv.c @@ -9,6 +9,7 @@ #include "tcpip.h" extern KMT_MESSAGE_HANDLER TestTdi; +extern KMT_MESSAGE_HANDLER TestConnect; static struct { @@ -16,7 +17,8 @@ static struct PKMT_MESSAGE_HANDLER Handler; } MessageHandlers[] = { - { IOCTL_TEST_TDI, TestTdi }, + { IOCTL_TEST_TDI, TestTdi }, + { IOCTL_TEST_CONNECT, TestConnect }, }; NTSTATUS diff --git a/rostests/kmtests/tcpip/TcpIp_user.c b/rostests/kmtests/tcpip/TcpIp_user.c index 2b5da02df6d..cf699810a0c 100644 --- a/rostests/kmtests/tcpip/TcpIp_user.c +++ b/rostests/kmtests/tcpip/TcpIp_user.c @@ -6,6 +6,7 @@ */ #include +#include #include "tcpip.h" @@ -35,3 +36,72 @@ START_TEST(TcpIpTdi) UnloadTcpIpTestDriver(); } + +static +DWORD +WINAPI +AcceptProc( + _In_ LPVOID Parameter) +{ + WORD WinsockVersion; + WSADATA WsaData; + int Error; + SOCKET ListenSocket, AcceptSocket; + struct sockaddr_in ListenAddress, AcceptAddress; + int AcceptAddressLength; + HANDLE ReadyToConnectEvent = (HANDLE)Parameter; + + /* Initialize winsock */ + WinsockVersion = MAKEWORD(2, 0); + Error = WSAStartup(WinsockVersion, &WsaData); + ok(Error == 0, ""); + + ListenSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ok_bool_true(ListenSocket != INVALID_SOCKET, "socket failed"); + + ZeroMemory(&ListenAddress, sizeof(ListenAddress)); + ListenAddress.sin_addr.S_un.S_addr = inet_addr("127.0.0.1"); + ListenAddress.sin_port = htons(TEST_CONNECT_SERVER_PORT); + ListenAddress.sin_family = AF_INET; + + Error = bind(ListenSocket, (struct sockaddr*)&ListenAddress, sizeof(ListenAddress)); + ok(Error == 0, ""); + + Error = listen(ListenSocket, 1); + ok(Error == 0, ""); + + SetEvent(ReadyToConnectEvent); + + AcceptAddressLength = sizeof(AcceptAddress); + AcceptSocket = accept(ListenSocket, (struct sockaddr*)&AcceptAddress, &AcceptAddressLength); + ok(AcceptSocket != INVALID_SOCKET, "\n"); + ok_eq_long(AcceptAddressLength, sizeof(AcceptAddress)); + ok_eq_hex(AcceptAddress.sin_addr.S_un.S_addr, inet_addr("127.0.0.1")); + ok_eq_hex(AcceptAddress.sin_port, ntohs(TEST_CONNECT_CLIENT_PORT)); + + return 0; +} + +START_TEST(TcpIpConnect) +{ + HANDLE AcceptThread; + HANDLE ReadyToConnectEvent; + + ReadyToConnectEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + ok(ReadyToConnectEvent != NULL, "\n"); + + AcceptThread = CreateThread(NULL, 0, AcceptProc, (PVOID)ReadyToConnectEvent, 0, NULL); + ok(AcceptThread != NULL, ""); + + WaitForSingleObject(ReadyToConnectEvent, INFINITE); + + LoadTcpIpTestDriver(); + + ok(KmtSendToDriver(IOCTL_TEST_CONNECT) == ERROR_SUCCESS, "\n"); + + WaitForSingleObject(AcceptThread, INFINITE); + + UnloadTcpIpTestDriver(); + + WSACleanup(); +} diff --git a/rostests/kmtests/tcpip/connect.c b/rostests/kmtests/tcpip/connect.c new file mode 100644 index 00000000000..366756f1cb3 --- /dev/null +++ b/rostests/kmtests/tcpip/connect.c @@ -0,0 +1,272 @@ +/* + * PROJECT: ReactOS kernel-mode tests + * LICENSE: LGPLv2+ - See COPYING.LIB in the top level directory + * PURPOSE: Kernel-Mode Test Suite for TCPIP.sys + * PROGRAMMER: Jérôme Gardou + */ + +#include +#include +#include + +#include + +#include "tcpip.h" + +#define TAG_TEST 'tseT' + +#if BYTE_ORDER == LITTLE_ENDIAN +USHORT +htons(USHORT x) +{ + return ((x & 0x00FF) << 8) | ((x & 0xFF00) >> 8); +} +#else +#define htons(x) (x) +#endif + +static +NTSTATUS +NTAPI +IrpCompletionRoutine( + _In_ PDEVICE_OBJECT DeviceObject, + _In_ PIRP Irp, + _In_ PVOID Context) +{ + UNREFERENCED_PARAMETER(DeviceObject); + UNREFERENCED_PARAMETER(Irp); + + KeSetEvent((PKEVENT)Context, IO_NETWORK_INCREMENT, FALSE); + + return STATUS_MORE_PROCESSING_REQUIRED; +} + +static +VOID +TestTcpConnect(void) +{ + PIRP Irp; + HANDLE AddressHandle, ConnectionHandle; + FILE_OBJECT* ConnectionFileObject; + DEVICE_OBJECT* DeviceObject; + UNICODE_STRING TcpDeviceName = RTL_CONSTANT_STRING(L"\\Device\\Tcp"); + NTSTATUS Status; + PFILE_FULL_EA_INFORMATION FileInfo; + TA_IP_ADDRESS* IpAddress; + TA_IP_ADDRESS ConnectAddress, ReturnAddress; + OBJECT_ATTRIBUTES ObjectAttributes; + IO_STATUS_BLOCK StatusBlock; + ULONG FileInfoSize; + IN_ADDR InAddr; + LPCWSTR AddressTerminator; + CONNECTION_CONTEXT ConnectionContext = (CONNECTION_CONTEXT)0xC0CAC01A; + KEVENT Event; + TDI_CONNECTION_INFORMATION RequestInfo, ReturnInfo; + + /* Create a TCP address file */ + FileInfoSize = FIELD_OFFSET(FILE_FULL_EA_INFORMATION, EaName[TDI_TRANSPORT_ADDRESS_LENGTH]) + 1 + sizeof(TA_IP_ADDRESS); + FileInfo = ExAllocatePoolWithTag(NonPagedPool, + FileInfoSize, + TAG_TEST); + ok(FileInfo != NULL, ""); + RtlZeroMemory(FileInfo, FileInfoSize); + + FileInfo->EaNameLength = TDI_TRANSPORT_ADDRESS_LENGTH; + FileInfo->EaValueLength = sizeof(TA_IP_ADDRESS); + RtlCopyMemory(&FileInfo->EaName[0], TdiTransportAddress, TDI_TRANSPORT_ADDRESS_LENGTH); + + IpAddress = (PTA_IP_ADDRESS)(&FileInfo->EaName[TDI_TRANSPORT_ADDRESS_LENGTH + 1]); + IpAddress->TAAddressCount = 1; + IpAddress->Address[0].AddressType = TDI_ADDRESS_TYPE_IP; + IpAddress->Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP; + IpAddress->Address[0].Address[0].sin_port = htons(TEST_CONNECT_CLIENT_PORT); + Status = RtlIpv4StringToAddressW(L"127.0.0.1", TRUE, &AddressTerminator, &InAddr); + ok_eq_hex(Status, STATUS_SUCCESS); + IpAddress->Address[0].Address[0].in_addr = InAddr.S_un.S_addr; + + InitializeObjectAttributes(&ObjectAttributes, + &TcpDeviceName, + OBJ_CASE_INSENSITIVE | OBJ_KERNEL_HANDLE, + NULL, + NULL); + + Status = ZwCreateFile( + &AddressHandle, + GENERIC_READ | GENERIC_WRITE, + &ObjectAttributes, + &StatusBlock, + 0, + FILE_ATTRIBUTE_NORMAL, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN_IF, + 0L, + FileInfo, + FileInfoSize); + ok_eq_hex(Status, STATUS_SUCCESS); + + ExFreePoolWithTag(FileInfo, TAG_TEST); + + /* Create a TCP connection file */ + FileInfoSize = FIELD_OFFSET(FILE_FULL_EA_INFORMATION, EaName[TDI_CONNECTION_CONTEXT_LENGTH]) + 1 + sizeof(CONNECTION_CONTEXT); + FileInfo = ExAllocatePoolWithTag(NonPagedPool, + FileInfoSize, + TAG_TEST); + ok(FileInfo != NULL, ""); + RtlZeroMemory(FileInfo, FileInfoSize); + + FileInfo->EaNameLength = TDI_CONNECTION_CONTEXT_LENGTH; + FileInfo->EaValueLength = sizeof(CONNECTION_CONTEXT); + RtlCopyMemory(&FileInfo->EaName[0], TdiConnectionContext, TDI_CONNECTION_CONTEXT_LENGTH); + *((CONNECTION_CONTEXT*)&FileInfo->EaName[TDI_CONNECTION_CONTEXT_LENGTH + 1]) = ConnectionContext; + + Status = ZwCreateFile( + &ConnectionHandle, + GENERIC_READ | GENERIC_WRITE, + &ObjectAttributes, + &StatusBlock, + 0, + FILE_ATTRIBUTE_NORMAL, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN_IF, + 0L, + FileInfo, + FileInfoSize); + ok_eq_hex(Status, STATUS_SUCCESS); + + ExFreePoolWithTag(FileInfo, TAG_TEST); + + /* Get the file and device object for the upcoming IRPs */ + Status = ObReferenceObjectByHandle( + ConnectionHandle, + GENERIC_READ, + *IoFileObjectType, + KernelMode, + (PVOID*)&ConnectionFileObject, + NULL); + ok_eq_hex(Status, STATUS_SUCCESS); + DeviceObject = IoGetRelatedDeviceObject(ConnectionFileObject); + ok(DeviceObject != NULL, "Device object is NULL!\n"); + + /* Associate the connection file and the address */ + KeInitializeEvent(&Event, NotificationEvent, FALSE); + Irp = IoAllocateIrp(DeviceObject->StackSize, FALSE); + ok(Irp != NULL, "IoAllocateIrp failed.\n"); + + TdiBuildAssociateAddress(Irp, DeviceObject, ConnectionFileObject, NULL, NULL, AddressHandle); + IoSetCompletionRoutine(Irp, IrpCompletionRoutine, &Event, TRUE, TRUE, TRUE); + + Status = IoCallDriver(DeviceObject, Irp); + if (Status == STATUS_PENDING) + { + trace("Associate address IRP is pending.\n"); + KeWaitForSingleObject( + &Event, + Executive, + KernelMode, + FALSE, + NULL); + Status = Irp->IoStatus.Status; + } + ok_eq_hex(Status, STATUS_SUCCESS); + IoFreeIrp(Irp); + + + KeClearEvent(&Event); + + /* Build the connect IRP. */ + Irp = IoAllocateIrp(DeviceObject->StackSize, FALSE); + ok(Irp != NULL, "IoAllocateIrp failed.\n"); + + /* Prepare the request */ + RtlZeroMemory(&RequestInfo, sizeof(RequestInfo)); + RtlZeroMemory(&ConnectAddress, sizeof(ConnectAddress)); + RequestInfo.RemoteAddressLength = sizeof(TA_IP_ADDRESS); + RequestInfo.RemoteAddress = &ConnectAddress; + ConnectAddress.TAAddressCount = 1; + ConnectAddress.Address[0].AddressType = TDI_ADDRESS_TYPE_IP; + ConnectAddress.Address[0].AddressLength = TDI_ADDRESS_LENGTH_IP; + ConnectAddress.Address[0].Address[0].sin_port = htons(TEST_CONNECT_SERVER_PORT); + Status = RtlIpv4StringToAddressW(L"127.0.0.1", TRUE, &AddressTerminator, &InAddr); + ConnectAddress.Address[0].Address[0].in_addr = InAddr.S_un.S_addr; + + /* See what we will get in exchange */ + RtlZeroMemory(&ReturnInfo, sizeof(ReturnInfo)); + RtlZeroMemory(&ReturnAddress, sizeof(ReturnAddress)); + ReturnInfo.RemoteAddressLength = sizeof(TA_IP_ADDRESS); + ReturnInfo.RemoteAddress = &ReturnAddress; + + TdiBuildConnect(Irp, + DeviceObject, + ConnectionFileObject, + NULL, + NULL, + NULL, + &RequestInfo, + &ReturnInfo); + IoSetCompletionRoutine(Irp, IrpCompletionRoutine, &Event, TRUE, TRUE, TRUE); + + Status = IoCallDriver(DeviceObject, Irp); + if (Status == STATUS_PENDING) + { + trace("Connect IRP is pending.\n"); + KeWaitForSingleObject( + &Event, + Executive, + KernelMode, + FALSE, + NULL); + Status = Irp->IoStatus.Status; + trace("Connect IRP completed.\n"); + } + ok_eq_hex(Status, STATUS_SUCCESS); + IoFreeIrp(Irp); + + /* The IRP doesn't touch the return info */ + ok_eq_long(ReturnInfo.RemoteAddressLength, sizeof(TA_IP_ADDRESS)); + ok_eq_pointer(ReturnInfo.RemoteAddress, &ReturnAddress); + ok_eq_long(ReturnInfo.OptionsLength, 0); + ok_eq_pointer(ReturnInfo.Options, NULL); + ok_eq_long(ReturnInfo.UserDataLength, 0); + ok_eq_pointer(ReturnInfo.UserData, NULL); + + ok_eq_long(ReturnAddress.TAAddressCount, 0); + ok_eq_hex(ReturnAddress.Address[0].AddressType, 0); + ok_eq_hex(ReturnAddress.Address[0].AddressLength, 0); + ok_eq_hex(ReturnAddress.Address[0].Address[0].sin_port, 0); + ok_eq_hex(ReturnAddress.Address[0].Address[0].in_addr, 0); + + ObDereferenceObject(ConnectionFileObject); + + ZwClose(ConnectionHandle); + ZwClose(AddressHandle); +} + +static KSTART_ROUTINE RunTest; +static +VOID +NTAPI +RunTest( + _In_ PVOID Context) +{ + UNREFERENCED_PARAMETER(Context); + + TestTcpConnect(); +} + +KMT_MESSAGE_HANDLER TestConnect; +NTSTATUS +TestConnect( + _In_ PDEVICE_OBJECT DeviceObject, + _In_ ULONG ControlCode, + _In_opt_ PVOID Buffer, + _In_ SIZE_T InLength, + _Inout_ PSIZE_T OutLength +) +{ + PKTHREAD Thread; + + Thread = KmtStartThread(RunTest, NULL); + KmtFinishThread(Thread, NULL); + + return STATUS_SUCCESS; +} diff --git a/rostests/kmtests/tcpip/tcpip.h b/rostests/kmtests/tcpip/tcpip.h index 45a5c504a0a..ba95ac24a0a 100644 --- a/rostests/kmtests/tcpip/tcpip.h +++ b/rostests/kmtests/tcpip/tcpip.h @@ -1,2 +1,7 @@ #define IOCTL_TEST_TDI 1 +#define IOCTL_TEST_CONNECT 2 + +/* For the TDI_CONNECT test */ +#define TEST_CONNECT_SERVER_PORT 12345 +#define TEST_CONNECT_CLIENT_PORT 54321