From 141378cfc8c7964516c1f7383110e440aac1f84d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herm=C3=A8s=20B=C3=A9lusca-Ma=C3=AFto?= Date: Tue, 28 Jul 2020 01:10:58 +0200 Subject: [PATCH] [CMD] ASSOC: Simplify the code and make it more robust; fix returned ERRORLEVEL values. - Make sure that non-administrator users can list associations, and display appropriate error messages when e.g. they don't have sufficient privileges to perform an operation. - Make the helper functions all return Win32 values, used as the ERRORVALUE, except when a specific extension association fails to be displayed, in which case the ERRORVALUE is normalized to 1. - Since the 'param' is a modifiable string (that can be modified by the command, independently of the way it's called), just use it to isolate the extension by zeroing out the equls-sign separator. --- base/shell/cmd/assoc.c | 315 +++++++++++++++++++++++------------------ 1 file changed, 179 insertions(+), 136 deletions(-) diff --git a/base/shell/cmd/assoc.c b/base/shell/cmd/assoc.c index 3e5611e6b1f..6441d3d369e 100644 --- a/base/shell/cmd/assoc.c +++ b/base/shell/cmd/assoc.c @@ -13,187 +13,232 @@ * * TODO: * - PrintAllAssociations could be optimized to not fetch all registry subkeys under 'Classes', just the ones that start with '.' - * - Make sure that non-administrator users can list associations, and get appropriate error messages when they don't have sufficient - * privileges to perform an operation. */ #include "precomp.h" #ifdef INCLUDE_CMD_ASSOC -static INT -PrintAssociation( - IN LPCTSTR extension) +static LONG +PrintAssociationEx( + IN HKEY hKeyClasses, + IN PCTSTR pszExtension) { - DWORD lRet; - HKEY hKey = NULL, hSubKey = NULL; - DWORD fileTypeLength = 0; - LPTSTR fileType = NULL; + LONG lRet; + HKEY hKey; + DWORD dwFileTypeLen = 0; + PTSTR pszFileType; - lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_READ, &hKey); + lRet = RegOpenKeyEx(hKeyClasses, pszExtension, 0, KEY_QUERY_VALUE, &hKey); if (lRet != ERROR_SUCCESS) - return -1; + { + if (lRet != ERROR_FILE_NOT_FOUND) + ErrorMessage(lRet, NULL); + return lRet; + } - lRet = RegOpenKeyEx(hKey, extension, 0, KEY_READ, &hSubKey); + /* Obtain the string length */ + lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, NULL, &dwFileTypeLen); + + /* If there is no default value, don't display it */ + if (lRet == ERROR_FILE_NOT_FOUND) + { + RegCloseKey(hKey); + return lRet; + } + if (lRet != ERROR_SUCCESS) + { + ErrorMessage(lRet, NULL); + RegCloseKey(hKey); + return lRet; + } + + ++dwFileTypeLen; + pszFileType = cmd_alloc(dwFileTypeLen * sizeof(TCHAR)); + if (!pszFileType) + { + WARN("Cannot allocate memory for pszFileType!\n"); + RegCloseKey(hKey); + return ERROR_NOT_ENOUGH_MEMORY; + } + + /* Obtain the actual file type */ + lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, (LPBYTE)pszFileType, &dwFileTypeLen); RegCloseKey(hKey); if (lRet != ERROR_SUCCESS) - return 0; - - /* Obtain string length */ - lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, NULL, &fileTypeLength); - - /* If there is no default value, don't display */ - if (lRet == ERROR_FILE_NOT_FOUND) { - RegCloseKey(hSubKey); - return 0; - } - if (lRet != ERROR_SUCCESS) - { - RegCloseKey(hSubKey); - return -2; + ErrorMessage(lRet, NULL); + cmd_free(pszFileType); + return lRet; } - fileType = cmd_alloc(fileTypeLength * sizeof(TCHAR)); - if (!fileType) + /* If there is a default key, display the relevant information */ + if (dwFileTypeLen != 0) { - WARN("Cannot allocate memory for fileType!\n"); - RegCloseKey(hSubKey); - return -2; + ConOutPrintf(_T("%s=%s\n"), pszExtension, pszFileType); } - /* Obtain actual file type */ - lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, (LPBYTE)fileType, &fileTypeLength); - RegCloseKey(hSubKey); - - if (lRet != ERROR_SUCCESS) - { - cmd_free(fileType); - return -2; - } - - /* If there is a default key, display relevant information */ - if (fileTypeLength != 0) - { - ConOutPrintf(_T("%s=%s\n"), extension, fileType); - } - - cmd_free(fileType); - return 1; + cmd_free(pszFileType); + return ERROR_SUCCESS; } -static INT +static LONG +PrintAssociation( + IN PCTSTR pszExtension) +{ + LONG lRet; + HKEY hKeyClasses; + + lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, + KEY_ENUMERATE_SUB_KEYS, &hKeyClasses); + if (lRet != ERROR_SUCCESS) + { + ErrorMessage(lRet, NULL); + return lRet; + } + + lRet = PrintAssociationEx(hKeyClasses, pszExtension); + + RegCloseKey(hKeyClasses); + return lRet; +} + +static LONG PrintAllAssociations(VOID) { - DWORD lRet = 0; - HKEY hKey = NULL; - DWORD numKeys = 0; + LONG lRet; + HKEY hKeyClasses; + DWORD dwKeyCtr; + DWORD dwNumKeys = 0; + DWORD dwExtLen = 0; + PTSTR pszExtName; - DWORD extLength = 0; - LPTSTR extName = NULL; - DWORD keyCtr = 0; - - lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_READ, &hKey); + lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, + KEY_QUERY_VALUE | KEY_ENUMERATE_SUB_KEYS, &hKeyClasses); if (lRet != ERROR_SUCCESS) - return -1; + { + ErrorMessage(lRet, NULL); + return lRet; + } - lRet = RegQueryInfoKey(hKey, NULL, NULL, NULL, &numKeys, &extLength, + lRet = RegQueryInfoKey(hKeyClasses, NULL, NULL, NULL, &dwNumKeys, &dwExtLen, NULL, NULL, NULL, NULL, NULL, NULL); if (lRet != ERROR_SUCCESS) { - RegCloseKey(hKey); - return -2; + ErrorMessage(lRet, NULL); + RegCloseKey(hKeyClasses); + return lRet; } - extLength++; - extName = cmd_alloc(extLength * sizeof(TCHAR)); - if (!extName) + ++dwExtLen; + pszExtName = cmd_alloc(dwExtLen * sizeof(TCHAR)); + if (!pszExtName) { - WARN("Cannot allocate memory for extName!\n"); - RegCloseKey(hKey); - return -2; + WARN("Cannot allocate memory for pszExtName!\n"); + RegCloseKey(hKeyClasses); + return ERROR_NOT_ENOUGH_MEMORY; } - for (keyCtr = 0; keyCtr < numKeys; ++keyCtr) + for (dwKeyCtr = 0; dwKeyCtr < dwNumKeys; ++dwKeyCtr) { - DWORD dwBufSize = extLength; - lRet = RegEnumKeyEx(hKey, keyCtr, extName, &dwBufSize, + DWORD dwBufSize = dwExtLen; + lRet = RegEnumKeyEx(hKeyClasses, dwKeyCtr, pszExtName, &dwBufSize, NULL, NULL, NULL, NULL); if (lRet == ERROR_SUCCESS || lRet == ERROR_MORE_DATA) { - if (*extName == _T('.')) - PrintAssociation(extName); + /* Name starts with '.': this is an extension */ + if (*pszExtName == _T('.')) + PrintAssociationEx(hKeyClasses, pszExtName); } else { - cmd_free(extName); - RegCloseKey(hKey); - return -1; + ErrorMessage(lRet, NULL); + cmd_free(pszExtName); + RegCloseKey(hKeyClasses); + return lRet; } } - RegCloseKey(hKey); + RegCloseKey(hKeyClasses); - cmd_free(extName); - return numKeys; + cmd_free(pszExtName); + return ERROR_SUCCESS; } -static INT +static LONG AddAssociation( - IN LPCTSTR extension, - IN LPCTSTR type) + IN PCTSTR pszExtension, + IN PCTSTR pszType) { - DWORD lRet; - HKEY hKey = NULL, hSubKey = NULL; + LONG lRet; + HKEY hKeyClasses, hKey; - lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_ALL_ACCESS, &hKey); + lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, + KEY_CREATE_SUB_KEY, &hKeyClasses); if (lRet != ERROR_SUCCESS) - return -1; + { + ErrorMessage(lRet, NULL); + return lRet; + } - lRet = RegCreateKeyEx(hKey, extension, 0, NULL, REG_OPTION_NON_VOLATILE, - KEY_ALL_ACCESS, NULL, &hSubKey, NULL); + lRet = RegCreateKeyEx(hKeyClasses, pszExtension, 0, NULL, REG_OPTION_NON_VOLATILE, + KEY_SET_VALUE, NULL, &hKey, NULL); + RegCloseKey(hKeyClasses); + + if (lRet != ERROR_SUCCESS) + { + ErrorMessage(lRet, NULL); + return lRet; + } + + lRet = RegSetValueEx(hKey, NULL, 0, REG_SZ, + (LPBYTE)pszType, (DWORD)(_tcslen(pszType) + 1) * sizeof(TCHAR)); RegCloseKey(hKey); if (lRet != ERROR_SUCCESS) - return -1; + { + ErrorMessage(lRet, NULL); + return lRet; + } - lRet = RegSetValueEx(hSubKey, NULL, 0, REG_SZ, - (LPBYTE)type, (_tcslen(type) + 1) * sizeof(TCHAR)); - RegCloseKey(hSubKey); - - if (lRet != ERROR_SUCCESS) - return -2; - - return 0; + return ERROR_SUCCESS; } -static INT +static LONG RemoveAssociation( - IN LPCTSTR extension) + IN PCTSTR pszExtension) { - DWORD lRet; - HKEY hKey; + LONG lRet; + HKEY hKeyClasses; - lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_ALL_ACCESS, &hKey); + lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, + KEY_QUERY_VALUE, &hKeyClasses); if (lRet != ERROR_SUCCESS) - return -1; + { + ErrorMessage(lRet, NULL); + return lRet; + } - lRet = RegDeleteKey(hKey, extension); - RegCloseKey(hKey); + lRet = RegDeleteKey(hKeyClasses, pszExtension); + RegCloseKey(hKeyClasses); if (lRet != ERROR_SUCCESS) - return -2; + { + if (lRet != ERROR_FILE_NOT_FOUND) + ErrorMessage(lRet, NULL); + return lRet; + } - return 0; + return ERROR_SUCCESS; } INT CommandAssoc(LPTSTR param) { INT retval = 0; - LPTSTR lpEqualSign; + PTCHAR pEqualSign; /* Print help */ if (!_tcsncmp(param, _T("/?"), 2)) @@ -202,53 +247,51 @@ INT CommandAssoc(LPTSTR param) return 0; } - if (_tcslen(param) == 0) + /* Print all associations if no parameter has been specified */ + if (!*param) { PrintAllAssociations(); goto Quit; } - lpEqualSign = _tcschr(param, _T('=')); - if (lpEqualSign != NULL) + pEqualSign = _tcschr(param, _T('=')); + if (pEqualSign != NULL) { - LPTSTR fileType = lpEqualSign + 1; - LPTSTR extension = cmd_alloc((lpEqualSign - param + 1) * sizeof(TCHAR)); - if (!extension) - { - WARN("Cannot allocate memory for extension!\n"); - error_out_of_memory(); - retval = 1; - goto Quit; - } + PTSTR pszFileType = pEqualSign + 1; - _tcsncpy(extension, param, lpEqualSign - param); - extension[lpEqualSign - param] = _T('\0'); + /* NULL-terminate at the equals sign */ + *pEqualSign = 0; - /* If the equal sign is the last character - * in the string, then delete the key. */ - if (_tcslen(fileType) == 0) + /* If the equals sign is the last character + * in the string, delete the association. */ + if (*pszFileType == 0) { - retval = RemoveAssociation(extension); + retval = RemoveAssociation(param); } else - /* Otherwise, add the key and print out the association */ + /* Otherwise, add the association and print it out */ { - retval = AddAssociation(extension, fileType); - PrintAssociation(extension); + retval = AddAssociation(param, pszFileType); + PrintAssociation(param); } - cmd_free(extension); - - if (retval) - retval = 1; /* Fixup the error value */ + if (retval != ERROR_SUCCESS) + { + if (retval != ERROR_FILE_NOT_FOUND) + { + // FIXME: Localize + ConErrPrintf(_T("Error occurred while processing: %s.\n"), param); + } + // retval = 1; /* Fixup the error value */ + } } else { - /* No equal sign, print all associations */ + /* No equals sign, print the association */ retval = PrintAssociation(param); - if (retval == 0) /* If nothing printed out */ + if (retval != ERROR_SUCCESS) { - ConOutResPrintf(STRING_ASSOC_ERROR, param); + ConErrResPrintf(STRING_ASSOC_ERROR, param); retval = 1; /* Fixup the error value */ } }