reactos/modules/rostests/apitests/afd/AfdHelpers.c
Thomas Faber 680d69d373
[AFD_APITEST] Introduce a test for directly creating and using sockets via AFD. CORE-9810
The initial tests in send.c validate correct behavior of send/sendto on
disconnected sockets (CORE-9810), fixed in r68129.
However, the helper functions are generic, so they can be used for additional
tests against AFD. Because AFD's create packet structure changes between
Windows versions, the functions check the OS version to determine the right
layout.
Tests succeed on Win2003 as well as Win10.
2018-03-05 14:52:56 +01:00

385 lines
12 KiB
C

/*
* PROJECT: ReactOS API Tests
* LICENSE: LGPL-2.1+ (https://spdx.org/licenses/LGPL-2.1+)
* PURPOSE: Utility function definitions for calling AFD
* COPYRIGHT: Copyright 2015-2018 Thomas Faber (thomas.faber@reactos.org)
*/
#include "precomp.h"
#define DD_UDP_DEVICE_NAME L"\\Device\\Udp"
/*
 * Layout of the AFD create packet that AfdCreateSocket places in the
 * extended-attribute value when the OS major version is 6 or higher.
 * The structure is variable-length: the transport device name follows the
 * fixed fields, so sizes are computed with FIELD_OFFSET(..., TransportName)
 * plus the name length rather than with sizeof.
 */
typedef struct _AFD_CREATE_PACKET_NT6 {
/* AFD_ENDPOINT_* flags chosen from the socket type */
DWORD EndpointFlags;
/* AfdCreateSocket always passes 0 here */
DWORD GroupID;
/* Address family, socket type and protocol exactly as given by the caller */
DWORD AddressFamily;
DWORD SocketType;
DWORD Protocol;
/* Length of the transport name as taken from UNICODE_STRING.Length */
DWORD SizeOfTransportName;
/* First character of the transport name; the rest (and a terminating
 * UNICODE_NULL) is stored directly behind the structure */
WCHAR TransportName[1];
} AFD_CREATE_PACKET_NT6, *PAFD_CREATE_PACKET_NT6;
/*
 * Opens a new AFD endpoint by creating a file on \Device\Afd\Endpoint with an
 * extended attribute that carries the AFD create packet.
 *
 * The create packet has two layouts; the one used is selected from the major
 * OS version reported by GetVersion() (6 and above use AFD_CREATE_PACKET_NT6,
 * older versions use AFD_CREATE_PACKET).
 *
 * SocketHandle  - receives the endpoint handle; set to NULL up front
 * AddressFamily - stored in the NT6 packet only
 * SocketType    - SOCK_STREAM selects the TCP transport name, anything else
 *                 the UDP one; also determines the endpoint flags
 * Protocol      - stored in the NT6 packet only
 *
 * Returns STATUS_INSUFFICIENT_RESOURCES if the EA buffer cannot be allocated,
 * otherwise the status returned by NtCreateFile.
 */
NTSTATUS
AfdCreateSocket(
    _Out_ PHANDLE SocketHandle,
    _In_ int AddressFamily,
    _In_ int SocketType,
    _In_ int Protocol)
{
    ANSI_STRING CommandName = RTL_CONSTANT_STRING(AfdCommand);
    UNICODE_STRING TcpDevice = RTL_CONSTANT_STRING(DD_TCP_DEVICE_NAME);
    UNICODE_STRING UdpDevice = RTL_CONSTANT_STRING(DD_UDP_DEVICE_NAME);
    UNICODE_STRING EndpointPath = RTL_CONSTANT_STRING(L"\\Device\\Afd\\Endpoint");
    UNICODE_STRING Transport;
    OBJECT_ATTRIBUTES Attributes;
    IO_STATUS_BLOCK Iosb;
    PFILE_FULL_EA_INFORMATION Ea;
    PVOID PacketStart;
    ULONG PacketSize;
    ULONG TotalSize;
    ULONG Flags = 0;
    BOOLEAN UseNt6Layout;
    NTSTATUS Result;

    *SocketHandle = NULL;

    if (SocketType == SOCK_STREAM)
    {
        Transport = TcpDevice;
    }
    else
    {
        Transport = UdpDevice;
    }

    /* The packet layout differs between OS versions; decide once */
    UseNt6Layout = (LOBYTE(LOWORD(GetVersion())) >= 6);

    /* Fixed header of the selected layout, then the name and its terminator */
    PacketSize = Transport.Length + sizeof(UNICODE_NULL);
    if (UseNt6Layout)
    {
        PacketSize += FIELD_OFFSET(AFD_CREATE_PACKET_NT6, TransportName);
    }
    else
    {
        PacketSize += FIELD_OFFSET(AFD_CREATE_PACKET, TransportName);
    }

    /* EA header + NUL-terminated EA name + EA value (the create packet) */
    TotalSize = FIELD_OFFSET(FILE_FULL_EA_INFORMATION, EaName) +
                CommandName.Length + sizeof(ANSI_NULL) +
                PacketSize;

    Ea = RtlAllocateHeap(RtlGetProcessHeap(), HEAP_ZERO_MEMORY, TotalSize);
    if (Ea == NULL)
    {
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    Ea->NextEntryOffset = 0;
    Ea->Flags = 0;
    Ea->EaNameLength = CommandName.Length;
    Ea->EaValueLength = PacketSize;
    RtlCopyMemory(Ea->EaName,
                  CommandName.Buffer,
                  CommandName.Length + sizeof(ANSI_NULL));

    /* The EA value starts right behind the terminated EA name */
    PacketStart = Ea->EaName + Ea->EaNameLength + sizeof(ANSI_NULL);

    /* Any other socket type keeps the flags at zero */
    switch (SocketType)
    {
        case SOCK_DGRAM:
            Flags = AFD_ENDPOINT_CONNECTIONLESS;
            break;
        case SOCK_STREAM:
            Flags = AFD_ENDPOINT_MESSAGE_ORIENTED;
            break;
        default:
            break;
    }

    if (UseNt6Layout)
    {
        PAFD_CREATE_PACKET_NT6 Packet6 = PacketStart;

        Packet6->EndpointFlags = Flags;
        Packet6->GroupID = 0;
        Packet6->AddressFamily = AddressFamily;
        Packet6->SocketType = SocketType;
        Packet6->Protocol = Protocol;
        Packet6->SizeOfTransportName = Transport.Length;
        RtlCopyMemory(Packet6->TransportName,
                      Transport.Buffer,
                      Transport.Length + sizeof(UNICODE_NULL));
    }
    else
    {
        PAFD_CREATE_PACKET Packet = PacketStart;

        Packet->EndpointFlags = Flags;
        Packet->GroupID = 0;
        Packet->SizeOfTransportName = Transport.Length;
        RtlCopyMemory(Packet->TransportName,
                      Transport.Buffer,
                      Transport.Length + sizeof(UNICODE_NULL));
    }

    InitializeObjectAttributes(&Attributes,
                               &EndpointPath,
                               OBJ_CASE_INSENSITIVE | OBJ_INHERIT,
                               0,
                               0);

    Result = NtCreateFile(SocketHandle,
                          GENERIC_READ | GENERIC_WRITE | SYNCHRONIZE,
                          &Attributes,
                          &Iosb,
                          NULL,
                          0,
                          FILE_SHARE_READ | FILE_SHARE_WRITE,
                          FILE_OPEN_IF,
                          0,
                          Ea,
                          TotalSize);

    /* The EA buffer is only needed for the duration of the create call */
    RtlFreeHeap(RtlGetProcessHeap(), 0, Ea);

    return Result;
}
/*
 * Binds an AFD endpoint to a local address by issuing IOCTL_AFD_BIND.
 *
 * The sockaddr is repackaged into an AFD_BIND_DATA holding one transport
 * address entry: sa_family becomes the address type and the bytes starting at
 * sa_data become the address payload.
 *
 * SocketHandle  - endpoint handle the IOCTL is sent to
 * Address       - local address to bind to
 * AddressLength - size of *Address in bytes, including the part that
 *                 precedes sa_data
 *
 * Returns STATUS_INVALID_PARAMETER if AddressLength is too small to reach
 * sa_data, the NtCreateEvent status if the completion event cannot be
 * created, STATUS_INSUFFICIENT_RESOURCES if allocation fails, and otherwise
 * the IOCTL status (taken from the I/O status block when the request
 * completed asynchronously).
 */
NTSTATUS
AfdBind(
    _In_ HANDLE SocketHandle,
    _In_ const struct sockaddr *Address,
    _In_ ULONG AddressLength)
{
    const ULONG HeaderLength = (ULONG)FIELD_OFFSET(struct sockaddr, sa_data);
    NTSTATUS Status;
    IO_STATUS_BLOCK IoStatus;
    PAFD_BIND_DATA BindInfo;
    ULONG BindInfoLength;
    ULONG PayloadLength;
    HANDLE Event;

    /*
     * The payload length is computed by unsigned subtraction. Without this
     * check a too-short AddressLength would wrap around, producing an
     * undersized allocation and an oversized copy below.
     */
    if (AddressLength < HeaderLength)
    {
        return STATUS_INVALID_PARAMETER;
    }
    PayloadLength = AddressLength - HeaderLength;

    Status = NtCreateEvent(&Event,
                           EVENT_ALL_ACCESS,
                           NULL,
                           NotificationEvent,
                           FALSE);
    if (!NT_SUCCESS(Status))
    {
        return Status;
    }

    /* Fixed part up to the address payload, followed by the payload itself */
    BindInfoLength = FIELD_OFFSET(AFD_BIND_DATA, Address.Address[0].Address) +
                     PayloadLength;
    BindInfo = RtlAllocateHeap(RtlGetProcessHeap(),
                               0,
                               BindInfoLength);
    if (!BindInfo)
    {
        NtClose(Event);
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    BindInfo->ShareType = AFD_SHARE_UNIQUE;
    BindInfo->Address.TAAddressCount = 1;
    BindInfo->Address.Address[0].AddressType = Address->sa_family;
    BindInfo->Address.Address[0].AddressLength = PayloadLength;
    RtlCopyMemory(&BindInfo->Address.Address[0].Address,
                  Address->sa_data,
                  BindInfo->Address.Address[0].AddressLength);

    /* The same buffer is passed as both input and output of the IOCTL */
    Status = NtDeviceIoControlFile(SocketHandle,
                                   Event,
                                   NULL,
                                   NULL,
                                   &IoStatus,
                                   IOCTL_AFD_BIND,
                                   BindInfo,
                                   BindInfoLength,
                                   BindInfo,
                                   BindInfoLength);
    if (Status == STATUS_PENDING)
    {
        /* Wait for completion; the final result is in the status block */
        NtWaitForSingleObject(Event, FALSE, NULL);
        Status = IoStatus.Status;
    }

    RtlFreeHeap(RtlGetProcessHeap(), 0, BindInfo);
    NtClose(Event);
    return Status;
}
/*
 * Connects an AFD endpoint to a remote address by issuing IOCTL_AFD_CONNECT.
 *
 * The sockaddr is repackaged into an AFD_CONNECT_INFO holding one transport
 * address entry: sa_family becomes the address type and the bytes starting at
 * sa_data become the address payload.
 *
 * SocketHandle  - endpoint handle the IOCTL is sent to
 * Address       - remote address to connect to
 * AddressLength - size of *Address in bytes, including the part that
 *                 precedes sa_data
 *
 * Returns STATUS_INVALID_PARAMETER if AddressLength is too small to reach
 * sa_data, the NtCreateEvent status if the completion event cannot be
 * created, STATUS_INSUFFICIENT_RESOURCES if allocation fails, and otherwise
 * the IOCTL status (taken from the I/O status block when the request
 * completed asynchronously).
 */
NTSTATUS
AfdConnect(
    _In_ HANDLE SocketHandle,
    _In_ const struct sockaddr *Address,
    _In_ ULONG AddressLength)
{
    const ULONG HeaderLength = (ULONG)FIELD_OFFSET(struct sockaddr, sa_data);
    NTSTATUS Status;
    IO_STATUS_BLOCK IoStatus;
    PAFD_CONNECT_INFO ConnectInfo;
    ULONG ConnectInfoLength;
    ULONG PayloadLength;
    HANDLE Event;

    /*
     * The payload length is computed by unsigned subtraction. Without this
     * check a too-short AddressLength would wrap around, producing an
     * undersized allocation and an oversized copy below.
     */
    if (AddressLength < HeaderLength)
    {
        return STATUS_INVALID_PARAMETER;
    }
    PayloadLength = AddressLength - HeaderLength;

    Status = NtCreateEvent(&Event,
                           EVENT_ALL_ACCESS,
                           NULL,
                           NotificationEvent,
                           FALSE);
    if (!NT_SUCCESS(Status))
    {
        return Status;
    }

    /* Sanity check on the structure layout this helper was written against */
    ASSERT(FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) == 20);

    /* Fixed part up to the address payload, followed by the payload itself */
    ConnectInfoLength = FIELD_OFFSET(AFD_CONNECT_INFO, RemoteAddress.Address[0].Address) +
                        PayloadLength;
    ConnectInfo = RtlAllocateHeap(RtlGetProcessHeap(),
                                  0,
                                  ConnectInfoLength);
    if (!ConnectInfo)
    {
        NtClose(Event);
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    ConnectInfo->UseSAN = FALSE;
    ConnectInfo->Root = 0;
    ConnectInfo->Unknown = 0;
    ConnectInfo->RemoteAddress.TAAddressCount = 1;
    ConnectInfo->RemoteAddress.Address[0].AddressType = Address->sa_family;
    ConnectInfo->RemoteAddress.Address[0].AddressLength = PayloadLength;
    RtlCopyMemory(&ConnectInfo->RemoteAddress.Address[0].Address,
                  Address->sa_data,
                  ConnectInfo->RemoteAddress.Address[0].AddressLength);

    Status = NtDeviceIoControlFile(SocketHandle,
                                   Event,
                                   NULL,
                                   NULL,
                                   &IoStatus,
                                   IOCTL_AFD_CONNECT,
                                   ConnectInfo,
                                   ConnectInfoLength,
                                   NULL,
                                   0);
    if (Status == STATUS_PENDING)
    {
        /* Wait for completion; the final result is in the status block */
        NtWaitForSingleObject(Event, FALSE, NULL);
        Status = IoStatus.Status;
    }

    RtlFreeHeap(RtlGetProcessHeap(), 0, ConnectInfo);
    NtClose(Event);
    return Status;
}
/*
 * Sends a single buffer on an AFD endpoint by issuing IOCTL_AFD_SEND.
 *
 * SocketHandle - endpoint handle the IOCTL is sent to
 * Buffer       - data to send
 * BufferLength - number of bytes at Buffer
 *
 * Returns the NtCreateEvent status if the completion event cannot be
 * created, otherwise the IOCTL status (taken from the I/O status block when
 * the request completed asynchronously).
 */
NTSTATUS
AfdSend(
    _In_ HANDLE SocketHandle,
    _In_ const void *Buffer,
    _In_ ULONG BufferLength)
{
    IO_STATUS_BLOCK Iosb;
    AFD_WSABUF DataBuffer;
    AFD_SEND_INFO Request;
    HANDLE CompletionEvent;
    NTSTATUS Result;

    /* Describe the caller's data as a one-element buffer array */
    DataBuffer.len = BufferLength;
    DataBuffer.buf = (PVOID)Buffer;

    Request.AfdFlags = 0;
    Request.TdiFlags = 0;
    Request.BufferCount = 1;
    Request.BufferArray = &DataBuffer;

    Result = NtCreateEvent(&CompletionEvent,
                           EVENT_ALL_ACCESS,
                           NULL,
                           NotificationEvent,
                           FALSE);
    if (!NT_SUCCESS(Result))
    {
        return Result;
    }

    Result = NtDeviceIoControlFile(SocketHandle,
                                   CompletionEvent,
                                   NULL,
                                   NULL,
                                   &Iosb,
                                   IOCTL_AFD_SEND,
                                   &Request,
                                   sizeof(Request),
                                   NULL,
                                   0);
    if (Result == STATUS_PENDING)
    {
        /* Block until the request finishes, then pick up its final status */
        NtWaitForSingleObject(CompletionEvent, FALSE, NULL);
        Result = Iosb.Status;
    }

    NtClose(CompletionEvent);
    return Result;
}
/*
 * Sends a single buffer to an explicit destination by issuing
 * IOCTL_AFD_SEND_DATAGRAM.
 *
 * The sockaddr is repackaged into a TRANSPORT_ADDRESS holding one entry:
 * sa_family becomes the address type and the bytes starting at sa_data become
 * the address payload.
 *
 * SocketHandle  - endpoint handle the IOCTL is sent to
 * Buffer        - data to send
 * BufferLength  - number of bytes at Buffer
 * Address       - destination address
 * AddressLength - size of *Address in bytes, including the part that
 *                 precedes sa_data
 *
 * Returns STATUS_INVALID_PARAMETER if AddressLength is too small to reach
 * sa_data, the NtCreateEvent status if the completion event cannot be
 * created, STATUS_INSUFFICIENT_RESOURCES if allocation fails, and otherwise
 * the IOCTL status (taken from the I/O status block when the request
 * completed asynchronously).
 */
NTSTATUS
AfdSendTo(
    _In_ HANDLE SocketHandle,
    _In_ const void *Buffer,
    _In_ ULONG BufferLength,
    _In_ const struct sockaddr *Address,
    _In_ ULONG AddressLength)
{
    const ULONG HeaderLength = (ULONG)FIELD_OFFSET(struct sockaddr, sa_data);
    NTSTATUS Status;
    IO_STATUS_BLOCK IoStatus;
    AFD_SEND_INFO_UDP SendInfo;
    HANDLE Event;
    AFD_WSABUF AfdBuffer;
    PTRANSPORT_ADDRESS TransportAddress;
    ULONG TransportAddressLength;
    ULONG PayloadLength;

    /*
     * The payload length is computed by unsigned subtraction. Without this
     * check a too-short AddressLength would wrap around, producing an
     * undersized allocation and an oversized copy below.
     */
    if (AddressLength < HeaderLength)
    {
        return STATUS_INVALID_PARAMETER;
    }
    PayloadLength = AddressLength - HeaderLength;

    Status = NtCreateEvent(&Event,
                           EVENT_ALL_ACCESS,
                           NULL,
                           NotificationEvent,
                           FALSE);
    if (!NT_SUCCESS(Status))
    {
        return Status;
    }

    /* Fixed part up to the address payload, followed by the payload itself */
    TransportAddressLength = FIELD_OFFSET(TRANSPORT_ADDRESS, Address[0].Address) +
                             PayloadLength;
    TransportAddress = RtlAllocateHeap(RtlGetProcessHeap(),
                                       0,
                                       TransportAddressLength);
    if (!TransportAddress)
    {
        NtClose(Event);
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    TransportAddress->TAAddressCount = 1;
    TransportAddress->Address[0].AddressType = Address->sa_family;
    TransportAddress->Address[0].AddressLength = PayloadLength;
    RtlCopyMemory(&TransportAddress->Address[0].Address,
                  Address->sa_data,
                  TransportAddress->Address[0].AddressLength);

    /* Describe the caller's data as a one-element buffer array */
    AfdBuffer.buf = (PVOID)Buffer;
    AfdBuffer.len = BufferLength;

    /* Zero the request first so fields not set below are well-defined */
    RtlZeroMemory(&SendInfo, sizeof(SendInfo));
    SendInfo.BufferArray = &AfdBuffer;
    SendInfo.BufferCount = 1;
    SendInfo.AfdFlags = 0;
    SendInfo.TdiConnection.RemoteAddress = TransportAddress;
    SendInfo.TdiConnection.RemoteAddressLength = TransportAddressLength;

    Status = NtDeviceIoControlFile(SocketHandle,
                                   Event,
                                   NULL,
                                   NULL,
                                   &IoStatus,
                                   IOCTL_AFD_SEND_DATAGRAM,
                                   &SendInfo,
                                   sizeof(SendInfo),
                                   NULL,
                                   0);
    if (Status == STATUS_PENDING)
    {
        /* Wait for completion; the final result is in the status block */
        NtWaitForSingleObject(Event, FALSE, NULL);
        Status = IoStatus.Status;
    }

    RtlFreeHeap(RtlGetProcessHeap(), 0, TransportAddress);
    NtClose(Event);
    return Status;
}