[CMD] ASSOC: Simplify the code and make it more robust; fix returned ERRORLEVEL values.

- Make sure that non-administrator users can list associations, and
  display appropriate error messages when e.g. they don't have sufficient
  privileges to perform an operation.

- Make the helper functions all return Win32 values, used as the
  ERRORVALUE, except when a specific extension association fails to be
  displayed, in which case the ERRORVALUE is normalized to 1.

- Since the 'param' is a modifiable string (that can be modified by the
  command, independently of the way it's called), just use it to isolate
  the extension by zeroing out the equls-sign separator.
This commit is contained in:
Hermès Bélusca-Maïto 2020-07-28 01:10:58 +02:00
parent 63316df520
commit 141378cfc8
No known key found for this signature in database
GPG key ID: 3B2539C65E7B93D0

View file

@ -13,187 +13,232 @@
* *
* TODO: * TODO:
* - PrintAllAssociations could be optimized to not fetch all registry subkeys under 'Classes', just the ones that start with '.' * - PrintAllAssociations could be optimized to not fetch all registry subkeys under 'Classes', just the ones that start with '.'
* - Make sure that non-administrator users can list associations, and get appropriate error messages when they don't have sufficient
* privileges to perform an operation.
*/ */
#include "precomp.h" #include "precomp.h"
#ifdef INCLUDE_CMD_ASSOC #ifdef INCLUDE_CMD_ASSOC
static INT static LONG
PrintAssociation( PrintAssociationEx(
IN LPCTSTR extension) IN HKEY hKeyClasses,
IN PCTSTR pszExtension)
{ {
DWORD lRet; LONG lRet;
HKEY hKey = NULL, hSubKey = NULL; HKEY hKey;
DWORD fileTypeLength = 0; DWORD dwFileTypeLen = 0;
LPTSTR fileType = NULL; PTSTR pszFileType;
lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_READ, &hKey); lRet = RegOpenKeyEx(hKeyClasses, pszExtension, 0, KEY_QUERY_VALUE, &hKey);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -1; {
if (lRet != ERROR_FILE_NOT_FOUND)
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegOpenKeyEx(hKey, extension, 0, KEY_READ, &hSubKey); /* Obtain the string length */
lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, NULL, &dwFileTypeLen);
/* If there is no default value, don't display it */
if (lRet == ERROR_FILE_NOT_FOUND)
{
RegCloseKey(hKey);
return lRet;
}
if (lRet != ERROR_SUCCESS)
{
ErrorMessage(lRet, NULL);
RegCloseKey(hKey);
return lRet;
}
++dwFileTypeLen;
pszFileType = cmd_alloc(dwFileTypeLen * sizeof(TCHAR));
if (!pszFileType)
{
WARN("Cannot allocate memory for pszFileType!\n");
RegCloseKey(hKey);
return ERROR_NOT_ENOUGH_MEMORY;
}
/* Obtain the actual file type */
lRet = RegQueryValueEx(hKey, NULL, NULL, NULL, (LPBYTE)pszFileType, &dwFileTypeLen);
RegCloseKey(hKey); RegCloseKey(hKey);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return 0;
/* Obtain string length */
lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, NULL, &fileTypeLength);
/* If there is no default value, don't display */
if (lRet == ERROR_FILE_NOT_FOUND)
{ {
RegCloseKey(hSubKey); ErrorMessage(lRet, NULL);
return 0; cmd_free(pszFileType);
} return lRet;
if (lRet != ERROR_SUCCESS)
{
RegCloseKey(hSubKey);
return -2;
} }
fileType = cmd_alloc(fileTypeLength * sizeof(TCHAR)); /* If there is a default key, display the relevant information */
if (!fileType) if (dwFileTypeLen != 0)
{ {
WARN("Cannot allocate memory for fileType!\n"); ConOutPrintf(_T("%s=%s\n"), pszExtension, pszFileType);
RegCloseKey(hSubKey);
return -2;
} }
/* Obtain actual file type */ cmd_free(pszFileType);
lRet = RegQueryValueEx(hSubKey, NULL, NULL, NULL, (LPBYTE)fileType, &fileTypeLength); return ERROR_SUCCESS;
RegCloseKey(hSubKey);
if (lRet != ERROR_SUCCESS)
{
cmd_free(fileType);
return -2;
}
/* If there is a default key, display relevant information */
if (fileTypeLength != 0)
{
ConOutPrintf(_T("%s=%s\n"), extension, fileType);
}
cmd_free(fileType);
return 1;
} }
static INT static LONG
PrintAssociation(
IN PCTSTR pszExtension)
{
LONG lRet;
HKEY hKeyClasses;
lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
KEY_ENUMERATE_SUB_KEYS, &hKeyClasses);
if (lRet != ERROR_SUCCESS)
{
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = PrintAssociationEx(hKeyClasses, pszExtension);
RegCloseKey(hKeyClasses);
return lRet;
}
static LONG
PrintAllAssociations(VOID) PrintAllAssociations(VOID)
{ {
DWORD lRet = 0; LONG lRet;
HKEY hKey = NULL; HKEY hKeyClasses;
DWORD numKeys = 0; DWORD dwKeyCtr;
DWORD dwNumKeys = 0;
DWORD dwExtLen = 0;
PTSTR pszExtName;
DWORD extLength = 0; lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
LPTSTR extName = NULL; KEY_QUERY_VALUE | KEY_ENUMERATE_SUB_KEYS, &hKeyClasses);
DWORD keyCtr = 0;
lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_READ, &hKey);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -1; {
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegQueryInfoKey(hKey, NULL, NULL, NULL, &numKeys, &extLength, lRet = RegQueryInfoKey(hKeyClasses, NULL, NULL, NULL, &dwNumKeys, &dwExtLen,
NULL, NULL, NULL, NULL, NULL, NULL); NULL, NULL, NULL, NULL, NULL, NULL);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
{ {
RegCloseKey(hKey); ErrorMessage(lRet, NULL);
return -2; RegCloseKey(hKeyClasses);
return lRet;
} }
extLength++; ++dwExtLen;
extName = cmd_alloc(extLength * sizeof(TCHAR)); pszExtName = cmd_alloc(dwExtLen * sizeof(TCHAR));
if (!extName) if (!pszExtName)
{ {
WARN("Cannot allocate memory for extName!\n"); WARN("Cannot allocate memory for pszExtName!\n");
RegCloseKey(hKey); RegCloseKey(hKeyClasses);
return -2; return ERROR_NOT_ENOUGH_MEMORY;
} }
for (keyCtr = 0; keyCtr < numKeys; ++keyCtr) for (dwKeyCtr = 0; dwKeyCtr < dwNumKeys; ++dwKeyCtr)
{ {
DWORD dwBufSize = extLength; DWORD dwBufSize = dwExtLen;
lRet = RegEnumKeyEx(hKey, keyCtr, extName, &dwBufSize, lRet = RegEnumKeyEx(hKeyClasses, dwKeyCtr, pszExtName, &dwBufSize,
NULL, NULL, NULL, NULL); NULL, NULL, NULL, NULL);
if (lRet == ERROR_SUCCESS || lRet == ERROR_MORE_DATA) if (lRet == ERROR_SUCCESS || lRet == ERROR_MORE_DATA)
{ {
if (*extName == _T('.')) /* Name starts with '.': this is an extension */
PrintAssociation(extName); if (*pszExtName == _T('.'))
PrintAssociationEx(hKeyClasses, pszExtName);
} }
else else
{ {
cmd_free(extName); ErrorMessage(lRet, NULL);
RegCloseKey(hKey); cmd_free(pszExtName);
return -1; RegCloseKey(hKeyClasses);
return lRet;
} }
} }
RegCloseKey(hKey); RegCloseKey(hKeyClasses);
cmd_free(extName); cmd_free(pszExtName);
return numKeys; return ERROR_SUCCESS;
} }
static INT static LONG
AddAssociation( AddAssociation(
IN LPCTSTR extension, IN PCTSTR pszExtension,
IN LPCTSTR type) IN PCTSTR pszType)
{ {
DWORD lRet; LONG lRet;
HKEY hKey = NULL, hSubKey = NULL; HKEY hKeyClasses, hKey;
lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_ALL_ACCESS, &hKey); lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
KEY_CREATE_SUB_KEY, &hKeyClasses);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -1; {
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegCreateKeyEx(hKey, extension, 0, NULL, REG_OPTION_NON_VOLATILE, lRet = RegCreateKeyEx(hKeyClasses, pszExtension, 0, NULL, REG_OPTION_NON_VOLATILE,
KEY_ALL_ACCESS, NULL, &hSubKey, NULL); KEY_SET_VALUE, NULL, &hKey, NULL);
RegCloseKey(hKeyClasses);
if (lRet != ERROR_SUCCESS)
{
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegSetValueEx(hKey, NULL, 0, REG_SZ,
(LPBYTE)pszType, (DWORD)(_tcslen(pszType) + 1) * sizeof(TCHAR));
RegCloseKey(hKey); RegCloseKey(hKey);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -1; {
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegSetValueEx(hSubKey, NULL, 0, REG_SZ, return ERROR_SUCCESS;
(LPBYTE)type, (_tcslen(type) + 1) * sizeof(TCHAR));
RegCloseKey(hSubKey);
if (lRet != ERROR_SUCCESS)
return -2;
return 0;
} }
static INT static LONG
RemoveAssociation( RemoveAssociation(
IN LPCTSTR extension) IN PCTSTR pszExtension)
{ {
DWORD lRet; LONG lRet;
HKEY hKey; HKEY hKeyClasses;
lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0, KEY_ALL_ACCESS, &hKey); lRet = RegOpenKeyEx(HKEY_LOCAL_MACHINE, _T("SOFTWARE\\Classes"), 0,
KEY_QUERY_VALUE, &hKeyClasses);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -1; {
ErrorMessage(lRet, NULL);
return lRet;
}
lRet = RegDeleteKey(hKey, extension); lRet = RegDeleteKey(hKeyClasses, pszExtension);
RegCloseKey(hKey); RegCloseKey(hKeyClasses);
if (lRet != ERROR_SUCCESS) if (lRet != ERROR_SUCCESS)
return -2; {
if (lRet != ERROR_FILE_NOT_FOUND)
ErrorMessage(lRet, NULL);
return lRet;
}
return 0; return ERROR_SUCCESS;
} }
INT CommandAssoc(LPTSTR param) INT CommandAssoc(LPTSTR param)
{ {
INT retval = 0; INT retval = 0;
LPTSTR lpEqualSign; PTCHAR pEqualSign;
/* Print help */ /* Print help */
if (!_tcsncmp(param, _T("/?"), 2)) if (!_tcsncmp(param, _T("/?"), 2))
@ -202,53 +247,51 @@ INT CommandAssoc(LPTSTR param)
return 0; return 0;
} }
if (_tcslen(param) == 0) /* Print all associations if no parameter has been specified */
if (!*param)
{ {
PrintAllAssociations(); PrintAllAssociations();
goto Quit; goto Quit;
} }
lpEqualSign = _tcschr(param, _T('=')); pEqualSign = _tcschr(param, _T('='));
if (lpEqualSign != NULL) if (pEqualSign != NULL)
{ {
LPTSTR fileType = lpEqualSign + 1; PTSTR pszFileType = pEqualSign + 1;
LPTSTR extension = cmd_alloc((lpEqualSign - param + 1) * sizeof(TCHAR));
if (!extension)
{
WARN("Cannot allocate memory for extension!\n");
error_out_of_memory();
retval = 1;
goto Quit;
}
_tcsncpy(extension, param, lpEqualSign - param); /* NULL-terminate at the equals sign */
extension[lpEqualSign - param] = _T('\0'); *pEqualSign = 0;
/* If the equal sign is the last character /* If the equals sign is the last character
* in the string, then delete the key. */ * in the string, delete the association. */
if (_tcslen(fileType) == 0) if (*pszFileType == 0)
{ {
retval = RemoveAssociation(extension); retval = RemoveAssociation(param);
} }
else else
/* Otherwise, add the key and print out the association */ /* Otherwise, add the association and print it out */
{ {
retval = AddAssociation(extension, fileType); retval = AddAssociation(param, pszFileType);
PrintAssociation(extension); PrintAssociation(param);
} }
cmd_free(extension); if (retval != ERROR_SUCCESS)
{
if (retval) if (retval != ERROR_FILE_NOT_FOUND)
retval = 1; /* Fixup the error value */ {
// FIXME: Localize
ConErrPrintf(_T("Error occurred while processing: %s.\n"), param);
}
// retval = 1; /* Fixup the error value */
}
} }
else else
{ {
/* No equal sign, print all associations */ /* No equals sign, print the association */
retval = PrintAssociation(param); retval = PrintAssociation(param);
if (retval == 0) /* If nothing printed out */ if (retval != ERROR_SUCCESS)
{ {
ConOutResPrintf(STRING_ASSOC_ERROR, param); ConErrResPrintf(STRING_ASSOC_ERROR, param);
retval = 1; /* Fixup the error value */ retval = 1; /* Fixup the error value */
} }
} }