[WIN32K:NTUSER] Avoid TOCTOU in ProbeAndCaptureUnicodeStringOrAtom.

This commit is contained in:
Thomas Faber 2023-09-09 08:59:57 -04:00
parent ce43bf6ba7
commit f9212e4a72
No known key found for this signature in database
GPG key ID: 076E7C3D44720826

View file

@ -152,32 +152,35 @@ ProbeAndCaptureUnicodeStringOrAtom(
__in_data_source(USER_MODE) _In_ PUNICODE_STRING pustrUnsafe) __in_data_source(USER_MODE) _In_ PUNICODE_STRING pustrUnsafe)
{ {
NTSTATUS Status = STATUS_SUCCESS; NTSTATUS Status = STATUS_SUCCESS;
UNICODE_STRING ustrCopy;
/* Default to NULL */ /* Default to NULL */
pustrOut->Buffer = NULL; RtlInitEmptyUnicodeString(pustrOut, NULL, 0);
_SEH2_TRY _SEH2_TRY
{ {
ProbeForRead(pustrUnsafe, sizeof(UNICODE_STRING), 1); ProbeForRead(pustrUnsafe, sizeof(UNICODE_STRING), 1);
ustrCopy = *pustrUnsafe;
/* Validate the string */ /* Validate the string */
if ((pustrUnsafe->Length & 1) || (pustrUnsafe->Buffer == NULL)) if ((ustrCopy.Length & 1) || (ustrCopy.Buffer == NULL))
{ {
/* This is not legal */ /* This is not legal */
_SEH2_YIELD(return STATUS_INVALID_PARAMETER); _SEH2_YIELD(return STATUS_INVALID_PARAMETER);
} }
/* Check if this is an atom */ /* Check if this is an atom */
if (IS_ATOM(pustrUnsafe->Buffer)) if (IS_ATOM(ustrCopy.Buffer))
{ {
/* Copy the atom, length is 0 */ /* Copy the atom, length is 0 */
pustrOut->MaximumLength = pustrOut->Length = 0; pustrOut->MaximumLength = pustrOut->Length = 0;
pustrOut->Buffer = pustrUnsafe->Buffer; pustrOut->Buffer = ustrCopy.Buffer;
} }
else else
{ {
/* Get the length, maximum length includes zero termination */ /* Get the length, maximum length includes zero termination */
pustrOut->Length = pustrUnsafe->Length; pustrOut->Length = ustrCopy.Length;
pustrOut->MaximumLength = pustrOut->Length + sizeof(WCHAR); pustrOut->MaximumLength = pustrOut->Length + sizeof(WCHAR);
/* Allocate a buffer */ /* Allocate a buffer */
@ -190,8 +193,8 @@ ProbeAndCaptureUnicodeStringOrAtom(
} }
/* Copy the string and zero terminate it */ /* Copy the string and zero terminate it */
ProbeForRead(pustrUnsafe->Buffer, pustrOut->Length, 1); ProbeForRead(ustrCopy.Buffer, pustrOut->Length, 1);
RtlCopyMemory(pustrOut->Buffer, pustrUnsafe->Buffer, pustrOut->Length); RtlCopyMemory(pustrOut->Buffer, ustrCopy.Buffer, pustrOut->Length);
pustrOut->Buffer[pustrOut->Length / sizeof(WCHAR)] = L'\0'; pustrOut->Buffer[pustrOut->Length / sizeof(WCHAR)] = L'\0';
} }
} }