[NTOS:KE:X64] Improve kernel stack switching on GUI system calls

To be 100% correct and not rely on assumptions, stack switching can only be done when all previous code - starting with the syscall entry point - is pure asm code, since we can't rely on the C compiler to not use stack addresses in a way that is not transparent. Therefore the new code uses the same mechanism as for normal system calls, returning the address of the asm function KiConvertToGuiThread, which is then called like an Nt* function would normally be called. KiConvertToGuiThread then allocates a new stack, switches to it (which is now fine, since all the code is asm), frees the old stack, calls PsConvertToGuiThread (which will no longer try to allocate another stack, since we already have one) and then jumps into the middle of KiSystemCallEntry64, where the system call is handled again.
Also simplify KiSystemCallEntry64 a bit by copying the first parameters into the trap frame, which avoids allocating additional stack space for the call to KiSystemCallHandler; that space now overlaps with the space that is allocated for the Nt* function.
Finally, fix the locations where r10 and r11 are stored, which are TrapFrame->Rcx and TrapFrame->EFlags respectively, based on the situation in user mode.
This commit is contained in:
Timo Kreuzer 2018-02-06 20:52:16 +01:00
parent a6732905b8
commit 18b1aafd82
2 changed files with 100 additions and 62 deletions

View file

@ -356,22 +356,31 @@ NTSTATUS
NtSyscallFailure(void) NtSyscallFailure(void)
{ {
/* This is the failure function */ /* This is the failure function */
return STATUS_ACCESS_VIOLATION; return (NTSTATUS)KeGetCurrentThread()->TrapFrame->Rax;
} }
PVOID PVOID
KiSystemCallHandler( KiSystemCallHandler(
IN PKTRAP_FRAME TrapFrame, _In_ ULONG64 ReturnAddress,
IN ULONG64 P2, _In_ ULONG64 P2,
IN ULONG64 P3, _In_ ULONG64 P3,
IN ULONG64 P4) _In_ ULONG64 P4)
{ {
PKTRAP_FRAME TrapFrame;
PKSERVICE_TABLE_DESCRIPTOR DescriptorTable; PKSERVICE_TABLE_DESCRIPTOR DescriptorTable;
PKTHREAD Thread; PKTHREAD Thread;
PULONG64 KernelParams, UserParams; PULONG64 KernelParams, UserParams;
ULONG ServiceNumber, Offset, Count; ULONG ServiceNumber, Offset, Count;
ULONG64 UserRsp; ULONG64 UserRsp;
NTSTATUS Status;
/* Get a pointer to the trap frame */
TrapFrame = (PKTRAP_FRAME)((PULONG64)_AddressOfReturnAddress() + 1 + MAX_SYSCALL_PARAMS);
/* Save some values in the trap frame */
TrapFrame->Rip = ReturnAddress;
TrapFrame->Rdx = P2;
TrapFrame->R8 = P3;
TrapFrame->R9 = P4;
/* Increase system call count */ /* Increase system call count */
__addgsdword(FIELD_OFFSET(KIPCR, Prcb.KeSystemCalls), 1); __addgsdword(FIELD_OFFSET(KIPCR, Prcb.KeSystemCalls), 1);
@ -422,27 +431,12 @@ KiSystemCallHandler(
return (PVOID)NtSyscallFailure; return (PVOID)NtSyscallFailure;
} }
/* Convert us to a GUI thread -- must wrap in ASM to get new EBP */ /* Convert us to a GUI thread
Status = KiConvertToGuiThread(); To be entirely correct. we return KiConvertToGuiThread,
which allocates a new stack, switches to it, calls
/* Reload trap frame and descriptor table pointer from new stack */ PsConvertToGuiThread and resumes in the middle of
TrapFrame = *(volatile PVOID*)&Thread->TrapFrame; KiSystemCallEntry64 to restart the system call handling. */
DescriptorTable = (PVOID)(*(volatile ULONG_PTR*)&Thread->ServiceTable + Offset); return (PVOID)KiConvertToGuiThread;
if (!NT_SUCCESS(Status))
{
/* Set the last error and fail */
TrapFrame->Rax = Status;
return (PVOID)NtSyscallFailure;
}
/* Validate the system call number again */
if (ServiceNumber >= DescriptorTable->Limit)
{
/* Fail the call */
TrapFrame->Rax = STATUS_INVALID_SYSTEM_SERVICE;
return (PVOID)NtSyscallFailure;
}
} }
/* Get stack bytes and calculate argument count */ /* Get stack bytes and calculate argument count */
@ -464,10 +458,10 @@ KiSystemCallHandler(
case 7: KernelParams[6] = UserParams[6]; case 7: KernelParams[6] = UserParams[6];
case 6: KernelParams[5] = UserParams[5]; case 6: KernelParams[5] = UserParams[5];
case 5: KernelParams[4] = UserParams[4]; case 5: KernelParams[4] = UserParams[4];
case 4: KernelParams[3] = P4; case 4:
case 3: KernelParams[2] = P3; case 3:
case 2: KernelParams[1] = P2; case 2:
case 1: KernelParams[0] = TrapFrame->R10; case 1:
case 0: case 0:
break; break;

View file

@ -21,6 +21,9 @@ EXTERN KiXmmExceptionHandler:PROC
EXTERN KiDeliverApc:PROC EXTERN KiDeliverApc:PROC
EXTERN KiDpcInterruptHandler:PROC EXTERN KiDpcInterruptHandler:PROC
EXTERN PsConvertToGuiThread:PROC EXTERN PsConvertToGuiThread:PROC
EXTERN MmCreateKernelStack:PROC
EXTERN KeSwitchKernelStack:PROC
EXTERN MmDeleteKernelStack:PROC
#ifdef _WINKD_ #ifdef _WINKD_
EXTERN KdSetOwedBreakpoints:PROC EXTERN KdSetOwedBreakpoints:PROC
@ -720,8 +723,6 @@ ENDFUNC
#define MAX_SYSCALL_PARAM_SIZE (16 * 8) #define MAX_SYSCALL_PARAM_SIZE (16 * 8)
#define HOME_SIZE 6*8
#define SYSCALL_ALLOCATION (MAX_SYSCALL_PARAM_SIZE + HOME_SIZE)
EXTERN KiSystemCallHandler:PROC EXTERN KiSystemCallHandler:PROC
@ -752,41 +753,44 @@ PUBLIC KiSystemCallEntry64
mov rsp, gs:[PcRspBase] mov rsp, gs:[PcRspBase]
/* Allocate a TRAP_FRAME and space for parameters */ /* Allocate a TRAP_FRAME and space for parameters */
sub rsp, (KTRAP_FRAME_LENGTH + MAX_SYSCALL_PARAM_SIZE + HOME_SIZE) sub rsp, (KTRAP_FRAME_LENGTH + MAX_SYSCALL_PARAM_SIZE)
#if DBG #if DBG
/* Save rbp and load it with the old stack pointer */ /* Save rbp and load it with the old stack pointer */
mov [rsp + HOME_SIZE + MAX_SYSCALL_PARAM_SIZE + HOME_SIZE + KTRAP_FRAME_Rbp], rbp mov [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rbp], rbp
mov rbp, gs:[PcUserRsp] mov rbp, gs:[PcUserRsp]
#endif #endif
/* Save important volatiles in the trap frame */ /* Save important registers in the trap frame */
mov [rsp + HOME_SIZE + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rax], rax mov [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rax], rax
mov [rsp + HOME_SIZE + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rcx], rcx mov [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rcx], r10
mov [rsp + HOME_SIZE + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R10], r10 mov [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_EFlags], r11
mov [rsp + HOME_SIZE + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R11], r11
/* Set sane segments */ /* Set sane segments */
mov ax, (KGDT64_R3_DATA or RPL_MASK) mov ax, (KGDT64_R3_DATA or RPL_MASK)
mov ds, ax mov ds, ax
mov es, ax mov es, ax
.ENDP
.PROC KiSystemCall64Again
/* Old stack pointer is in rcx, lie and say we saved it in rbp */
.setframe rbp, 0
.endprolog
/* Call the C-handler (will enable interrupts) */ /* Call the C-handler (will enable interrupts) */
lea rcx, [rsp + SYSCALL_ALLOCATION]
call KiSystemCallHandler call KiSystemCallHandler
/* Deallocate the handlers home stack frame */ /* The return value from KiSystemCallHandler is the address of the Nt-function */
add rsp, HOME_SIZE mov rcx, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rcx]
mov rdx, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rdx]
/* The return value is the address of the Nt-function */ mov r8, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R8]
mov rcx, [rsp + 0] mov r9, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R9]
mov rdx, [rsp + 8]
mov r8, [rsp + 16]
mov r9, [rsp + 24]
call rax call rax
#if DBG #if DBG
/* Restore rbp */ /* Restore rbp */
mov rbp, [rsp + SYSCALL_ALLOCATION + KTRAP_FRAME_Rbp] mov rbp, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rbp]
test dword ptr [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_EFlags], HEX(200) test dword ptr [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_EFlags], HEX(200)
jnz IntsEnabled jnz IntsEnabled
@ -803,8 +807,8 @@ IntsEnabled:
mov [rcx + KTHREAD_TrapFrame], rdx mov [rcx + KTHREAD_TrapFrame], rdx
/* Prepare user mode return address (rcx) and eflags (r11) for sysret */ /* Prepare user mode return address (rcx) and eflags (r11) for sysret */
mov rcx, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rcx] mov rcx, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rip]
mov r11, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R11] mov r11, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_EFlags]
/* Load user mode stack (It was copied to the trap frame) */ /* Load user mode stack (It was copied to the trap frame) */
mov rsp, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rsp] mov rsp, [rsp + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rsp]
@ -815,6 +819,7 @@ IntsEnabled:
/* return to user mode */ /* return to user mode */
.byte HEX(48) // REX prefix to return to long mode .byte HEX(48) // REX prefix to return to long mode
sysret sysret
.ENDP .ENDP
@ -896,22 +901,61 @@ ENDFUNC
PUBLIC KiConvertToGuiThread PUBLIC KiConvertToGuiThread
FUNC KiConvertToGuiThread FUNC KiConvertToGuiThread
push rbp sub rsp, 40
.pushreg rbp .allocstack 40
sub rsp, 32
.allocstack 32
.endprolog .endprolog
// NewStack = (ULONG_PTR)MmCreateKernelStack(TRUE, 0);
mov cl, 1
xor rdx, rdx
call MmCreateKernelStack
/* Check for failure */
test rax, rax
jz KiConvertToGuiThreadFailed
; OldStack = KeSwitchKernelStack((PVOID)NewStack, (PVOID)(NewStack - KERNEL_STACK_SIZE));
mov rcx, rax
mov rdx, rax
sub rdx, KERNEL_STACK_SIZE
call KeSwitchKernelStack
// MmDeleteKernelStack(OldStack, FALSE);
mov rcx, rax
xor rdx, rdx
call MmDeleteKernelStack
/* Call the worker function */ /* Call the worker function */
call PsConvertToGuiThread call PsConvertToGuiThread
/* Adjust rsp */ /* Check for failure */
add rsp, 32 test rax, rax
js KiConvertToGuiThreadFailed
/* Restore rbp */ /* Disable interrupts for return */
pop rbp cli
/* return to the caller */ // FIXME: should just do the trap frame switch in KiSystemCallHandler64
/* Restore old trap frame */
mov rcx, gs:[PcCurrentThread]
mov rdx, [rsp + 48 + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_TrapFrame]
mov [rcx + KTHREAD_TrapFrame], rdx
// Restore register parameters
mov rcx, [rsp + 48 + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rip]
mov rdx, [rsp + 48 + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_Rdx]
mov r8, [rsp + 48 + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R8]
mov r9, [rsp + 48 + MAX_SYSCALL_PARAM_SIZE + KTRAP_FRAME_R9]
/* Run KiSystemCallHandler again */
add rsp, 48
jmp KiSystemCall64Again
KiConvertToGuiThreadFailed:
/* Clean up the stack and return failure */
add rsp, 40
mov eax, HEX(C0000017) // STATUS_NO_MEMORY
ret ret
ENDFUNC ENDFUNC
@ -977,7 +1021,7 @@ KiSwitchKernelStackHelper:
/* Return on new stack */ /* Return on new stack */
mov rax, rdx mov rax, rdx
ret; ret
#ifdef _MSC_VER #ifdef _MSC_VER