reactos/dll/win32/msafd/misc/helpers.c

521 lines
17 KiB
C

/*
* COPYRIGHT: See COPYING in the top level directory
* PROJECT: ReactOS Ancillary Function Driver DLL
* FILE: dll/win32/msafd/misc/helpers.c
* PURPOSE: Helper DLL management
* PROGRAMMERS: Casper S. Hornstrup (chorns@users.sourceforge.net)
* Alex Ionescu (alex@relsoft.net)
* REVISIONS:
* CSH 01/09-2000 Created
* Alex 16/07/2004 - Complete Rewrite
*/
#include <msafd.h>
#include <winreg.h>
CRITICAL_SECTION HelperDLLDatabaseLock;
LIST_ENTRY HelperDLLDatabaseListHead;
INT
SockGetTdiName(
PINT AddressFamily,
PINT SocketType,
PINT Protocol,
GROUP Group,
DWORD Flags,
PUNICODE_STRING TransportName,
PVOID *HelperDllContext,
PHELPER_DATA *HelperDllData,
PDWORD Events)
{
PHELPER_DATA HelperData;
PWSTR Transports;
PWSTR Transport;
PWINSOCK_MAPPING Mapping;
PLIST_ENTRY Helpers;
INT Status;
TRACE("AddressFamily %p, SocketType %p, Protocol %p, Group %u, Flags %lx, TransportName %p, HelperDllContext %p, HelperDllData %p, Events %p\n",
AddressFamily, SocketType, Protocol, Group, Flags, TransportName, HelperDllContext, HelperDllData, Events);
/* Check in our Current Loaded Helpers */
for (Helpers = SockHelpersListHead.Flink;
Helpers != &SockHelpersListHead;
Helpers = Helpers->Flink ) {
HelperData = CONTAINING_RECORD(Helpers, HELPER_DATA, Helpers);
/* See if this Mapping works for us */
if (SockIsTripleInMapping (HelperData->Mapping,
*AddressFamily,
*SocketType,
*Protocol)) {
/* Call the Helper Dll function get the Transport Name */
if (HelperData->WSHOpenSocket2 == NULL ) {
/* DLL Doesn't support WSHOpenSocket2, call the old one */
HelperData->WSHOpenSocket(AddressFamily,
SocketType,
Protocol,
TransportName,
HelperDllContext,
Events
);
} else {
HelperData->WSHOpenSocket2(AddressFamily,
SocketType,
Protocol,
Group,
Flags,
TransportName,
HelperDllContext,
Events
);
}
/* Return the Helper Pointers */
*HelperDllData = HelperData;
return NO_ERROR;
}
}
/* Get the Transports available */
Status = SockLoadTransportList(&Transports);
/* Check for error */
if (Status) {
WARN("Can't get transport list\n");
return Status;
}
/* Loop through each transport until we find one that can satisfy us */
for (Transport = Transports;
*Transports != 0;
Transport += wcslen(Transport) + 1) {
TRACE("Transport: %S\n", Transports);
/* See what mapping this Transport supports */
Status = SockLoadTransportMapping(Transport, &Mapping);
/* Check for error */
if (Status) {
ERR("Can't get mapping for %S\n", Transports);
HeapFree(GlobalHeap, 0, Transports);
return Status;
}
/* See if this Mapping works for us */
if (SockIsTripleInMapping(Mapping, *AddressFamily, *SocketType, *Protocol)) {
/* It does, so load the DLL associated with it */
Status = SockLoadHelperDll(Transport, Mapping, &HelperData);
/* Check for error */
if (Status) {
ERR("Can't load helper DLL for Transport %S.\n", Transport);
HeapFree(GlobalHeap, 0, Transports);
HeapFree(GlobalHeap, 0, Mapping);
return Status;
}
/* Call the Helper Dll function get the Transport Name */
if (HelperData->WSHOpenSocket2 == NULL) {
/* DLL Doesn't support WSHOpenSocket2, call the old one */
HelperData->WSHOpenSocket(AddressFamily,
SocketType,
Protocol,
TransportName,
HelperDllContext,
Events
);
} else {
HelperData->WSHOpenSocket2(AddressFamily,
SocketType,
Protocol,
Group,
Flags,
TransportName,
HelperDllContext,
Events
);
}
/* Return the Helper Pointers */
*HelperDllData = HelperData;
/* We actually cache these ... the can't be freed yet */
/*HeapFree(GlobalHeap, 0, Transports);*/
/*HeapFree(GlobalHeap, 0, Mapping);*/
return NO_ERROR;
}
HeapFree(GlobalHeap, 0, Mapping);
}
HeapFree(GlobalHeap, 0, Transports);
return WSAEINVAL;
}
INT
SockLoadTransportMapping(
PWSTR TransportName,
PWINSOCK_MAPPING *Mapping)
{
PWSTR TransportKey;
HKEY KeyHandle;
ULONG MappingSize;
LONG Status;
TRACE("TransportName %ws\n", TransportName);
/* Allocate a Buffer */
TransportKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
/* Check for error */
if (TransportKey == NULL) {
ERR("Buffer allocation failed\n");
return WSAEINVAL;
}
/* Generate the right key name */
wcscpy(TransportKey, L"System\\CurrentControlSet\\Services\\");
wcscat(TransportKey, TransportName);
wcscat(TransportKey, L"\\Parameters\\Winsock");
/* Open the Key */
Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, TransportKey, 0, KEY_READ, &KeyHandle);
/* We don't need the Transport Key anymore */
HeapFree(GlobalHeap, 0, TransportKey);
/* Check for error */
if (Status) {
ERR("Error reading transport mapping registry\n");
return WSAEINVAL;
}
/* Find out how much space we need for the Mapping */
Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, NULL, &MappingSize);
/* Check for error */
if (Status) {
ERR("Error reading transport mapping registry\n");
return WSAEINVAL;
}
/* Allocate Memory for the Mapping */
*Mapping = HeapAlloc(GlobalHeap, 0, MappingSize);
/* Check for error */
if (*Mapping == NULL) {
ERR("Buffer allocation failed\n");
return WSAEINVAL;
}
/* Read the Mapping */
Status = RegQueryValueExW(KeyHandle, L"Mapping", NULL, NULL, (LPBYTE)*Mapping, &MappingSize);
/* Check for error */
if (Status) {
ERR("Error reading transport mapping registry\n");
HeapFree(GlobalHeap, 0, *Mapping);
return WSAEINVAL;
}
/* Close key and return */
RegCloseKey(KeyHandle);
return 0;
}
INT
SockLoadTransportList(
PWSTR *TransportList)
{
ULONG TransportListSize;
HKEY KeyHandle;
LONG Status;
TRACE("Called\n");
/* Open the Transports Key */
Status = RegOpenKeyExW (HKEY_LOCAL_MACHINE,
L"SYSTEM\\CurrentControlSet\\Services\\Winsock\\Parameters",
0,
KEY_READ,
&KeyHandle);
/* Check for error */
if (Status) {
ERR("Error reading transport list registry\n");
return WSAEINVAL;
}
/* Get the Transport List Size */
Status = RegQueryValueExW(KeyHandle,
L"Transports",
NULL,
NULL,
NULL,
&TransportListSize);
/* Check for error */
if (Status) {
ERR("Error reading transport list registry\n");
return WSAEINVAL;
}
/* Allocate Memory for the Transport List */
*TransportList = HeapAlloc(GlobalHeap, 0, TransportListSize);
/* Check for error */
if (*TransportList == NULL) {
ERR("Buffer allocation failed\n");
return WSAEINVAL;
}
/* Get the Transports */
Status = RegQueryValueExW (KeyHandle,
L"Transports",
NULL,
NULL,
(LPBYTE)*TransportList,
&TransportListSize);
/* Check for error */
if (Status) {
ERR("Error reading transport list registry\n");
HeapFree(GlobalHeap, 0, *TransportList);
return WSAEINVAL;
}
/* Close key and return */
RegCloseKey(KeyHandle);
return 0;
}
INT
SockLoadHelperDll(
PWSTR TransportName,
PWINSOCK_MAPPING Mapping,
PHELPER_DATA *HelperDllData)
{
PHELPER_DATA HelperData;
PWSTR HelperDllName;
PWSTR FullHelperDllName;
PWSTR HelperKey;
HKEY KeyHandle;
ULONG DataSize;
LONG Status;
/* Allocate space for the Helper Structure and TransportName */
HelperData = HeapAlloc(GlobalHeap, 0, sizeof(*HelperData) + (wcslen(TransportName) + 1) * sizeof(WCHAR));
/* Check for error */
if (HelperData == NULL) {
ERR("Buffer allocation failed\n");
return WSAEINVAL;
}
/* Allocate Space for the Helper DLL Key */
HelperKey = HeapAlloc(GlobalHeap, 0, (54 + wcslen(TransportName)) * sizeof(WCHAR));
/* Check for error */
if (HelperKey == NULL) {
ERR("Buffer allocation failed\n");
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
/* Generate the right key name */
wcscpy(HelperKey, L"System\\CurrentControlSet\\Services\\");
wcscat(HelperKey, TransportName);
wcscat(HelperKey, L"\\Parameters\\Winsock");
/* Open the Key */
Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, HelperKey, 0, KEY_READ, &KeyHandle);
HeapFree(GlobalHeap, 0, HelperKey);
/* Check for error */
if (Status) {
ERR("Error reading helper DLL parameters\n");
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
/* Read Size of SockAddr Structures */
DataSize = sizeof(HelperData->MinWSAddressLength);
HelperData->MinWSAddressLength = 16;
RegQueryValueExW (KeyHandle,
L"MinSockaddrLength",
NULL,
NULL,
(LPBYTE)&HelperData->MinWSAddressLength,
&DataSize);
DataSize = sizeof(HelperData->MinWSAddressLength);
HelperData->MaxWSAddressLength = 16;
RegQueryValueExW (KeyHandle,
L"MaxSockaddrLength",
NULL,
NULL,
(LPBYTE)&HelperData->MaxWSAddressLength,
&DataSize);
/* Size of TDI Structures */
HelperData->MinTDIAddressLength = HelperData->MinWSAddressLength + 6;
HelperData->MaxTDIAddressLength = HelperData->MaxWSAddressLength + 6;
/* Read Delayed Acceptance Setting */
DataSize = sizeof(DWORD);
HelperData->UseDelayedAcceptance = FALSE;
RegQueryValueExW (KeyHandle,
L"UseDelayedAcceptance",
NULL,
NULL,
(LPBYTE)&HelperData->UseDelayedAcceptance,
&DataSize);
/* Allocate Space for the Helper DLL Names */
HelperDllName = HeapAlloc(GlobalHeap, 0, 512);
/* Check for error */
if (HelperDllName == NULL) {
ERR("Buffer allocation failed\n");
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
FullHelperDllName = HeapAlloc(GlobalHeap, 0, 512);
/* Check for error */
if (FullHelperDllName == NULL) {
ERR("Buffer allocation failed\n");
HeapFree(GlobalHeap, 0, HelperDllName);
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
/* Get the name of the Helper DLL*/
DataSize = 512;
Status = RegQueryValueExW (KeyHandle,
L"HelperDllName",
NULL,
NULL,
(LPBYTE)HelperDllName,
&DataSize);
/* Check for error */
if (Status) {
ERR("Error reading helper DLL parameters\n");
HeapFree(GlobalHeap, 0, FullHelperDllName);
HeapFree(GlobalHeap, 0, HelperDllName);
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
/* Get the Full name, expanding Environment Strings */
ExpandEnvironmentStringsW (HelperDllName,
FullHelperDllName,
256);
/* Load the DLL */
HelperData->hInstance = LoadLibraryW(FullHelperDllName);
HeapFree(GlobalHeap, 0, HelperDllName);
HeapFree(GlobalHeap, 0, FullHelperDllName);
if (HelperData->hInstance == NULL) {
ERR("Error loading helper DLL\n");
HeapFree(GlobalHeap, 0, HelperData);
return WSAEINVAL;
}
/* Close Key */
RegCloseKey(KeyHandle);
/* Get the Pointers to the Helper Routines */
HelperData->WSHOpenSocket = (PWSH_OPEN_SOCKET)
GetProcAddress(HelperData->hInstance,
"WSHOpenSocket");
HelperData->WSHOpenSocket2 = (PWSH_OPEN_SOCKET2)
GetProcAddress(HelperData->hInstance,
"WSHOpenSocket2");
HelperData->WSHJoinLeaf = (PWSH_JOIN_LEAF)
GetProcAddress(HelperData->hInstance,
"WSHJoinLeaf");
HelperData->WSHNotify = (PWSH_NOTIFY)
GetProcAddress(HelperData->hInstance, "WSHNotify");
HelperData->WSHGetSocketInformation = (PWSH_GET_SOCKET_INFORMATION)
GetProcAddress(HelperData->hInstance,
"WSHGetSocketInformation");
HelperData->WSHSetSocketInformation = (PWSH_SET_SOCKET_INFORMATION)
GetProcAddress(HelperData->hInstance,
"WSHSetSocketInformation");
HelperData->WSHGetSockaddrType = (PWSH_GET_SOCKADDR_TYPE)
GetProcAddress(HelperData->hInstance,
"WSHGetSockaddrType");
HelperData->WSHGetWildcardSockaddr = (PWSH_GET_WILDCARD_SOCKADDR)
GetProcAddress(HelperData->hInstance,
"WSHGetWildcardSockaddr");
HelperData->WSHGetBroadcastSockaddr = (PWSH_GET_BROADCAST_SOCKADDR)
GetProcAddress(HelperData->hInstance,
"WSHGetBroadcastSockaddr");
HelperData->WSHAddressToString = (PWSH_ADDRESS_TO_STRING)
GetProcAddress(HelperData->hInstance,
"WSHAddressToString");
HelperData->WSHStringToAddress = (PWSH_STRING_TO_ADDRESS)
GetProcAddress(HelperData->hInstance,
"WSHStringToAddress");
HelperData->WSHIoctl = (PWSH_IOCTL)
GetProcAddress(HelperData->hInstance,
"WSHIoctl");
/* Save the Mapping Structure and transport name */
HelperData->Mapping = Mapping;
wcscpy(HelperData->TransportName, TransportName);
/* Increment Reference Count */
HelperData->RefCount = 1;
/* Add it to our list */
InsertHeadList(&SockHelpersListHead, &HelperData->Helpers);
/* Return Pointers */
*HelperDllData = HelperData;
return 0;
}
BOOL
SockIsTripleInMapping(
PWINSOCK_MAPPING Mapping,
INT AddressFamily,
INT SocketType,
INT Protocol)
{
/* The Windows version returns more detailed information on which of the 3 parameters failed...we should do this later */
ULONG Row;
TRACE("Called, Mapping rows = %d\n", Mapping->Rows);
/* Loop through Mapping to Find a matching one */
for (Row = 0; Row < Mapping->Rows; Row++) {
TRACE("Examining: row %d: AF %d type %d proto %d\n",
Row,
(INT)Mapping->Mapping[Row].AddressFamily,
(INT)Mapping->Mapping[Row].SocketType,
(INT)Mapping->Mapping[Row].Protocol);
/* Check of all three values Match */
if (((INT)Mapping->Mapping[Row].AddressFamily == AddressFamily) &&
((INT)Mapping->Mapping[Row].SocketType == SocketType) &&
((INT)Mapping->Mapping[Row].Protocol == Protocol)) {
TRACE("Found\n");
return TRUE;
}
}
WARN("Not found\n");
return FALSE;
}
/* EOF */