diff --git a/reactos/lib/msafd/misc/dllmain.c b/reactos/lib/msafd/misc/dllmain.c index 9e1a6368b99..e32739700e5 100644 --- a/reactos/lib/msafd/misc/dllmain.c +++ b/reactos/lib/msafd/misc/dllmain.c @@ -198,15 +198,30 @@ WSPSocket( SocketType = lpProtocolInfo->iSocketType; Protocol = lpProtocolInfo->iProtocol; - Status = HelperDLL->EntryTable.lpWSHOpenSocket2( - &AddressFamily, - &SocketType, - &Protocol, - 0, - 0, - &TdiDeviceName, - &HelperContext, - &NotificationEvents); + /* The OPTIONAL export WSHOpenSocket2 supersedes WSHOpenSocket */ + if (HelperDLL->EntryTable.lpWSHOpenSocket2) + { + Status = HelperDLL->EntryTable.lpWSHOpenSocket2( + &AddressFamily, + &SocketType, + &Protocol, + 0, + 0, + &TdiDeviceName, + &HelperContext, + &NotificationEvents); + } + else + { + Status = HelperDLL->EntryTable.lpWSHOpenSocket( + &AddressFamily, + &SocketType, + &Protocol, + &TdiDeviceName, + &HelperContext, + &NotificationEvents); + } + if (Status != NO_ERROR) { AFD_DbgPrint(MAX_TRACE, ("WinSock Helper DLL failed (0x%X).\n", Status)); *lpErrno = Status; diff --git a/reactos/lib/msafd/misc/helpers.c b/reactos/lib/msafd/misc/helpers.c index 5d92e1613e9..26e665a2228 100644 --- a/reactos/lib/msafd/misc/helpers.c +++ b/reactos/lib/msafd/misc/helpers.c @@ -104,69 +104,67 @@ INT GetHelperDLLEntries( PWSHELPER_DLL HelperDLL) { PVOID e; - - e = GetProcAddress(HelperDLL->hModule, "WSHAddressToString"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHAddressToString) = e; - + + /* The following functions MUST be supported */ e = GetProcAddress(HelperDLL->hModule, "WSHEnumProtocols"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHEnumProtocols) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHEnumProtocols) = e; - e = GetProcAddress(HelperDLL->hModule, "WSHGetBroadcastSockaddr"); + e = GetProcAddress(HelperDLL->hModule, "WSHGetSockaddrType"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetBroadcastSockaddr) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHGetProviderGuid"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetProviderGuid) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHGetSockaddrType"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetSockaddrType) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHGetSockaddrType) = e; e = GetProcAddress(HelperDLL->hModule, "WSHGetSocketInformation"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetSocketInformation) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHGetSocketInformation) = e; e = GetProcAddress(HelperDLL->hModule, "WSHGetWildcardSockaddr"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetWildcardSockaddr) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHGetWildcardSockaddr) = e; e = GetProcAddress(HelperDLL->hModule, "WSHGetWinsockMapping"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetWinsockMapping) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHGetWSAProtocolInfo"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHGetWSAProtocolInfo) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHIoctl"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHIoctl) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHJoinLeaf"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHJoinLeaf) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHGetWinsockMapping) = e; e = GetProcAddress(HelperDLL->hModule, "WSHNotify"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHNotify) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHNotify) = e; e = GetProcAddress(HelperDLL->hModule, "WSHOpenSocket"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket) = e; - - e = GetProcAddress(HelperDLL->hModule, "WSHOpenSocket2"); - if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket2) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket) = e; e = GetProcAddress(HelperDLL->hModule, "WSHSetSocketInformation"); if (!e) return ERROR_BAD_PROVIDER; - ((PVOID) HelperDLL->EntryTable.lpWSHSetSocketInformation) = e; + ((PVOID) HelperDLL->EntryTable.lpWSHSetSocketInformation) = e; + + + /* + The following functions are OPTIONAL. + Whoever wants to call them, must check that the pointer is not NULL. + */ + e = GetProcAddress(HelperDLL->hModule, "WSHAddressToString"); + ((PVOID) HelperDLL->EntryTable.lpWSHAddressToString) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHGetBroadcastSockaddr"); + ((PVOID) HelperDLL->EntryTable.lpWSHGetBroadcastSockaddr) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHGetProviderGuid"); + ((PVOID) HelperDLL->EntryTable.lpWSHGetProviderGuid) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHGetWSAProtocolInfo"); + ((PVOID) HelperDLL->EntryTable.lpWSHGetWSAProtocolInfo) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHIoctl"); + ((PVOID) HelperDLL->EntryTable.lpWSHIoctl) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHJoinLeaf"); + ((PVOID) HelperDLL->EntryTable.lpWSHJoinLeaf) = e; + + e = GetProcAddress(HelperDLL->hModule, "WSHOpenSocket2"); + ((PVOID) HelperDLL->EntryTable.lpWSHOpenSocket2) = e; e = GetProcAddress(HelperDLL->hModule, "WSHStringToAddress"); - if (!e) return ERROR_BAD_PROVIDER; ((PVOID) HelperDLL->EntryTable.lpWSHStringToAddress) = e; return NO_ERROR;