diff --git a/reactos/ntoskrnl/mm/ARM3/virtual.c b/reactos/ntoskrnl/mm/ARM3/virtual.c index 18295d89d01..1be2106b827 100644 --- a/reactos/ntoskrnl/mm/ARM3/virtual.c +++ b/reactos/ntoskrnl/mm/ARM3/virtual.c @@ -26,6 +26,15 @@ MiProtectVirtualMemory(IN PEPROCESS Process, IN ULONG NewAccessProtection, OUT PULONG OldAccessProtection OPTIONAL); +VOID +NTAPI +MiFlushTbAndCapture(IN PMMVAD FoundVad, + IN PMMPTE PointerPte, + IN ULONG ProtectionMask, + IN PMMPFN Pfn1, + IN BOOLEAN CaptureDirtyBit); + + /* PRIVATE FUNCTIONS **********************************************************/ ULONG @@ -1719,7 +1728,17 @@ MiProtectVirtualMemory(IN PEPROCESS Process, OUT PULONG OldAccessProtection OPTIONAL) { PMEMORY_AREA MemoryArea; + PMMVAD Vad; + PMMSUPPORT AddressSpace; + ULONG_PTR StartingAddress, EndingAddress; + PMMPTE PointerPde, PointerPte, LastPte; + MMPTE PteContents; + PUSHORT UsedPageTableEntries; + PMMPFN Pfn1; + ULONG ProtectionMask; + NTSTATUS Status = STATUS_SUCCESS; + /* Check for ROS specific memory area */ MemoryArea = MmLocateMemoryAreaByAddress(&Process->Vm, *BaseAddress); if ((MemoryArea) && (MemoryArea->Type == MEMORY_AREA_SECTION_VIEW)) { @@ -1730,8 +1749,175 @@ MiProtectVirtualMemory(IN PEPROCESS Process, OldAccessProtection); } - UNIMPLEMENTED; - return STATUS_CONFLICTING_ADDRESSES; + /* Calcualte base address for the VAD */ + StartingAddress = (ULONG_PTR)PAGE_ALIGN((*BaseAddress)); + EndingAddress = (((ULONG_PTR)*BaseAddress + *NumberOfBytesToProtect - 1) | (PAGE_SIZE - 1)); + + /* Calculate the protection mask and make sure it's valid */ + ProtectionMask = MiMakeProtectionMask(NewAccessProtection); + if (ProtectionMask == MM_INVALID_PROTECTION) + { + DPRINT1("Invalid protection mask\n"); + return STATUS_INVALID_PAGE_PROTECTION; + } + + /* Lock the address space and make sure the process isn't already dead */ + AddressSpace = MmGetCurrentAddressSpace(); + MmLockAddressSpace(AddressSpace); + if (Process->VmDeleted) + { + DPRINT1("Process is dying\n"); + Status = STATUS_PROCESS_IS_TERMINATING; + goto FailPath; + } + + /* Get the VAD for this address range, and make sure it exists */ + Vad = (PMMVAD)MiCheckForConflictingNode(StartingAddress >> PAGE_SHIFT, + EndingAddress >> PAGE_SHIFT, + &Process->VadRoot); + if (!Vad) + { + DPRINT("Could not find a VAD for this allocation\n"); + Status = STATUS_CONFLICTING_ADDRESSES; + goto FailPath; + } + + /* Make sure the address is within this VAD's boundaries */ + if ((((ULONG_PTR)StartingAddress >> PAGE_SHIFT) < Vad->StartingVpn) || + (((ULONG_PTR)EndingAddress >> PAGE_SHIFT) > Vad->EndingVpn)) + { + Status = STATUS_CONFLICTING_ADDRESSES; + goto FailPath; + } + + /* These kinds of VADs are not supported atm */ + if ((Vad->u.VadFlags.VadType == VadAwe) || + (Vad->u.VadFlags.VadType == VadDevicePhysicalMemory) || + (Vad->u.VadFlags.VadType == VadLargePages)) + { + DPRINT1("Illegal VAD for attempting to set protection\n"); + Status = STATUS_CONFLICTING_ADDRESSES; + goto FailPath; + } + + /* Check for a VAD whose protection can't be changed */ + if (Vad->u.VadFlags.NoChange == 1) + { + DPRINT1("Trying to change protection of a NoChange VAD\n"); + Status = STATUS_INVALID_PAGE_PROTECTION; + goto FailPath; + } + + if (Vad->u.VadFlags.PrivateMemory == 0) + { + /* This is a section, handled by the ROS specific code above */ + UNIMPLEMENTED; + } + else + { + /* Private memory, check protection flags */ + if ((NewAccessProtection & PAGE_WRITECOPY) || + (NewAccessProtection & PAGE_EXECUTE_WRITECOPY)) + { + Status = STATUS_INVALID_PARAMETER_4; + goto FailPath; + } + + //MiLockProcessWorkingSet(Thread, Process); + + /* TODO: Check if all pages in this range are committed */ + + /* Compute starting and ending PTE and PDE addresses */ + PointerPde = MiAddressToPde(StartingAddress); + PointerPte = MiAddressToPte(StartingAddress); + LastPte = MiAddressToPte(EndingAddress); + + /* Make this PDE valid */ + MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL); + + /* Save protection of the first page */ + if (PointerPte->u.Long != 0) + { + /* Capture the page protection and make the PDE valid */ + *OldAccessProtection = MiGetPageProtection(PointerPte); + MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL); + } + else + { + /* Grab the old protection from the VAD itself */ + *OldAccessProtection = MmProtectToValue[Vad->u.VadFlags.Protection]; + } + + /* Loop all the PTEs now */ + while (PointerPte <= LastPte) + { + /* Check if we've crossed a PDE boundary and make the new PDE valid too */ + if ((((ULONG_PTR)PointerPte) & (SYSTEM_PD_SIZE - 1)) == 0) + { + PointerPde = MiAddressToPte(PointerPte); + MiMakePdeExistAndMakeValid(PointerPde, Process, MM_NOIRQL); + } + + /* Capture the PTE and see what we're dealing with */ + PteContents = *PointerPte; + if (PteContents.u.Long == 0) + { + /* This used to be a zero PTE and it no longer is, so we must add a + reference to the pagetable. */ + UsedPageTableEntries = &MmWorkingSetList->UsedPageTableEntries[MiGetPdeOffset(MiPteToAddress(PointerPte))]; + (*UsedPageTableEntries)++; + ASSERT((*UsedPageTableEntries) <= PTE_COUNT); + } + else if (PteContents.u.Hard.Valid == 1) + { + /* Get the PFN entry */ + Pfn1 = MiGetPfnEntry(PFN_FROM_PTE(&PteContents)); + + /* We don't support this yet */ + ASSERT(Pfn1->u3.e1.PrototypePte == 0); + + /* Check if the page should not be accessible at all */ + if ((NewAccessProtection & PAGE_NOACCESS) || + (NewAccessProtection & PAGE_GUARD)) + { + /* TODO */ + UNIMPLEMENTED; + } + + /* Write the protection mask and write it with a TLB flush */ + Pfn1->OriginalPte.u.Soft.Protection = ProtectionMask; + MiFlushTbAndCapture(Vad, + PointerPte, + ProtectionMask, + Pfn1, + TRUE); + } + else + { + /* We don't support these cases yet */ + ASSERT(PteContents.u.Soft.Prototype == 0); + ASSERT(PteContents.u.Soft.Transition == 0); + + /* The PTE is already demand-zero, just update the protection mask */ + PointerPte->u.Soft.Protection = ProtectionMask; + } + + PointerPte++; + } + + /* Unlock the working set and update quota charges if needed, then return */ + //MiUnlockProcessWorkingSet(Thread, Process); + } + +FailPath: + /* Unlock the address space */ + MmUnlockAddressSpace(AddressSpace); + + /* Return parameters */ + *NumberOfBytesToProtect = (SIZE_T)((PUCHAR)EndingAddress - (PUCHAR)StartingAddress + 1); + *BaseAddress = (PVOID)StartingAddress; + + return Status; } VOID