diff --git a/include/linux/mm.h b/include/linux/mm.h
index d64f2626d44c..1cf6e5d14c63 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -636,6 +636,7 @@ static inline void INIT_VMA(struct vm_area_struct *vma)
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
 	seqcount_init(&vma->vm_sequence);
+	atomic_set(&vma->vm_ref_count, 1);
 #endif
 }
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 7e0cffe6bdfb..ba11a7b2d76a 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -381,6 +381,7 @@ struct vm_area_struct {
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
 	seqcount_t vm_sequence;
+	atomic_t vm_ref_count;		/* see get_vma(), put_vma() */
 #endif
 } __randomize_layout;
@@ -401,6 +402,9 @@ struct mm_struct {
 		struct vm_area_struct *mmap;		/* list of VMAs */
 		struct rb_root mm_rb;
 		u64 vmacache_seqnum;			/* per-thread vmacache */
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+		rwlock_t mm_rb_lock;
+#endif
 #ifdef CONFIG_MMU
 		unsigned long (*get_unmapped_area) (struct file *filp,
 				unsigned long addr, unsigned long len,
diff --git a/kernel/fork.c b/kernel/fork.c
index 8ba76a608d79..5f4e4ce703c5 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1008,6 +1008,9 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	mm->mmap = NULL;
 	mm->mm_rb = RB_ROOT;
 	mm->vmacache_seqnum = 0;
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	rwlock_init(&mm->mm_rb_lock);
+#endif
 	atomic_set(&mm->mm_users, 1);
 	atomic_set(&mm->mm_count, 1);
 	seqcount_init(&mm->write_protect_seq);
diff --git a/mm/init-mm.c b/mm/init-mm.c
index 153162669f80..241def24524f 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -28,6 +28,9 @@
  */
 struct mm_struct init_mm = {
 	.mm_rb		= RB_ROOT,
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	.mm_rb_lock	= __RW_LOCK_UNLOCKED(init_mm.mm_rb_lock),
+#endif
 	.pgd		= swapper_pg_dir,
 	.mm_users	= ATOMIC_INIT(2),
 	.mm_count	= ATOMIC_INIT(1),
diff --git a/mm/internal.h b/mm/internal.h
index c43ccdddb0f6..4b156e129def 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -36,6 +36,12 @@ void page_writeback_init(void);
 
 vm_fault_t do_swap_page(struct vm_fault *vmf);
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+extern struct vm_area_struct *get_vma(struct mm_struct *mm,
+				      unsigned long addr);
+extern void put_vma(struct vm_area_struct *vma);
+#endif
+
 void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
 		unsigned long floor, unsigned long ceiling);
diff --git a/mm/mmap.c b/mm/mmap.c
index 0a15dc814751..ddc5725d5073 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -168,6 +168,27 @@ void unlink_file_vma(struct vm_area_struct *vma)
 	}
 }
 
+static void __free_vma(struct vm_area_struct *vma)
+{
+	if (vma->vm_file)
+		fput(vma->vm_file);
+	mpol_put(vma_policy(vma));
+	vm_area_free(vma);
+}
+
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+void put_vma(struct vm_area_struct *vma)
+{
+	if (atomic_dec_and_test(&vma->vm_ref_count))
+		__free_vma(vma);
+}
+#else
+static inline void put_vma(struct vm_area_struct *vma)
+{
+	__free_vma(vma);
+}
+#endif
+
 /*
  * Close a vm structure and free it, returning the next.
  */
@@ -178,10 +199,7 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
 	might_sleep();
 	if (vma->vm_ops && vma->vm_ops->close)
 		vma->vm_ops->close(vma);
-	if (vma->vm_file)
-		fput(vma->vm_file);
-	mpol_put(vma_policy(vma));
-	vm_area_free(vma);
+	put_vma(vma);
 	return next;
 }
 
@@ -434,6 +452,13 @@ static void validate_mm(struct mm_struct *mm)
 RB_DECLARE_CALLBACKS_MAX(static, vma_gap_callbacks,
 			 struct vm_area_struct, vm_rb,
 			 unsigned long, rb_subtree_gap, vma_compute_gap)
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+#define mm_rb_write_lock(mm)	write_lock(&(mm)->mm_rb_lock)
+#define mm_rb_write_unlock(mm)	write_unlock(&(mm)->mm_rb_lock)
+#else
+#define mm_rb_write_lock(mm)	do { } while (0)
+#define mm_rb_write_unlock(mm)	do { } while (0)
+#endif /* CONFIG_SPECULATIVE_PAGE_FAULT */
 
 /*
  * Update augmented rbtree rb_subtree_gap values after vma->vm_start or
@@ -450,26 +475,37 @@ static void vma_gap_update(struct vm_area_struct *vma)
 }
 
 static inline void vma_rb_insert(struct vm_area_struct *vma,
-				 struct rb_root *root)
+				 struct mm_struct *mm)
 {
+	struct rb_root *root = &mm->mm_rb;
+
 	/* All rb_subtree_gap values must be consistent prior to insertion */
 	validate_mm_rb(root, NULL);
 
 	rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
 }
 
-static void __vma_rb_erase(struct vm_area_struct *vma, struct rb_root *root)
+static void __vma_rb_erase(struct vm_area_struct *vma, struct mm_struct *mm)
 {
+	struct rb_root *root = &mm->mm_rb;
 	/*
 	 * Note rb_erase_augmented is a fairly large inline function,
 	 * so make sure we instantiate it only once with our desired
 	 * augmented rbtree callbacks.
 	 */
+	mm_rb_write_lock(mm);
 	rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
+	mm_rb_write_unlock(mm); /* wmb */
+
+	/*
+	 * Ensure the removal is complete before clearing the node.
+	 * Matched by vma_has_changed()/handle_speculative_fault().
+	 */
+	RB_CLEAR_NODE(&vma->vm_rb);
 }
 
 static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
-						struct rb_root *root,
+						struct mm_struct *mm,
 						struct vm_area_struct *ignore)
 {
 	/*
@@ -481,15 +517,15 @@ static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
 	 * b. the vma being erased in detach_vmas_to_be_unmapped() ->
 	 *    vma_rb_erase()
 	 */
-	validate_mm_rb(root, ignore);
+	validate_mm_rb(&mm->mm_rb, ignore);
 
-	__vma_rb_erase(vma, root);
+	__vma_rb_erase(vma, mm);
 }
 
 static __always_inline void vma_rb_erase(struct vm_area_struct *vma,
-					 struct rb_root *root)
+					 struct mm_struct *mm)
 {
-	vma_rb_erase_ignore(vma, root, vma);
+	vma_rb_erase_ignore(vma, mm, vma);
 }
 
 /*
@@ -648,10 +684,12 @@ void __vma_link_rb(struct mm_struct *mm, struct vm_area_struct *vma,
 	 * immediately update the gap to the correct value. Finally we
 	 * rebalance the rbtree after all augmented values have been set.
 	 */
+	mm_rb_write_lock(mm);
 	rb_link_node(&vma->vm_rb, rb_parent, rb_link);
 	vma->rb_subtree_gap = 0;
 	vma_gap_update(vma);
-	vma_rb_insert(vma, &mm->mm_rb);
+	vma_rb_insert(vma, mm);
+	mm_rb_write_unlock(mm);
 }
 
 static void __vma_link_file(struct vm_area_struct *vma)
@@ -723,7 +761,7 @@ static __always_inline void __vma_unlink(struct mm_struct *mm,
 						struct vm_area_struct *vma,
 						struct vm_area_struct *ignore)
 {
-	vma_rb_erase_ignore(vma, &mm->mm_rb, ignore);
+	vma_rb_erase_ignore(vma, mm, ignore);
 	__vma_unlink_list(mm, vma);
 	/* Kill the cache */
 	vmacache_invalidate(mm);
@@ -979,16 +1017,13 @@ int __vma_adjust(struct vm_area_struct *vma, unsigned long start,
 	}
 
 	if (remove_next) {
-		if (file) {
+		if (file)
 			uprobe_munmap(next, next->vm_start, next->vm_end);
-			fput(file);
-		}
 		if (next->anon_vma)
 			anon_vma_merge(vma, next);
 		mm->map_count--;
-		mpol_put(vma_policy(next));
 		vm_raw_write_end(next);
-		vm_area_free(next);
+		put_vma(next);
 		/*
 		 * In mprotect's case 6 (see comments on vma_merge),
 		 * we must remove another next too. It would clutter
@@ -2351,15 +2386,11 @@ get_unmapped_area(struct file *file, unsigned long addr, unsigned long len,
 EXPORT_SYMBOL(get_unmapped_area);
 
 /* Look up the first VMA which satisfies addr < vm_end, NULL if none. */
-struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
+static struct vm_area_struct *__find_vma(struct mm_struct *mm,
+					 unsigned long addr)
 {
 	struct rb_node *rb_node;
-	struct vm_area_struct *vma;
-
-	/* Check the cache first. */
-	vma = vmacache_find(mm, addr);
-	if (likely(vma))
-		return vma;
+	struct vm_area_struct *vma = NULL;
 
 	rb_node = mm->mm_rb.rb_node;
@@ -2377,13 +2408,40 @@ struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 			rb_node = rb_node->rb_right;
 	}
 
+	return vma;
+}
+
+struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma;
+
+	/* Check the cache first. */
+	vma = vmacache_find(mm, addr);
+	if (likely(vma))
+		return vma;
+
+	vma = __find_vma(mm, addr);
 	if (vma)
 		vmacache_update(addr, vma);
 
 	return vma;
 }
-
 EXPORT_SYMBOL(find_vma);
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma = NULL;
+
+	read_lock(&mm->mm_rb_lock);
+	vma = __find_vma(mm, addr);
+	if (vma)
+		atomic_inc(&vma->vm_ref_count);
+	read_unlock(&mm->mm_rb_lock);
+
+	return vma;
+}
+#endif
+
 /*
  * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
 */
@@ -2747,7 +2805,7 @@ detach_vmas_to_be_unmapped(struct mm_struct *mm, struct vm_area_struct *vma,
 	insertion_point = (prev ? &prev->vm_next : &mm->mmap);
 	vma->vm_prev = NULL;
 	do {
-		vma_rb_erase(vma, &mm->mm_rb);
+		vma_rb_erase(vma, mm);
 		mm->map_count--;
 		tail_vma = vma;
 		vma = vma->vm_next;
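
For readers following the locking scheme: get_vma() takes the reference while still holding mm_rb_lock for read, so a concurrent __vma_rb_erase() either runs before the lookup (the VMA is simply not found) or after the reference is taken (put_vma() then performs the final __free_vma()). The sketch below shows how a speculative fault path is expected to pair these helpers with the vm_sequence seqcount introduced earlier in the series. It is illustrative only, not part of this patch: the caller name spf_try_fault() and the validation steps are assumptions; the series' real entry point is handle_speculative_fault().

/*
 * Illustrative sketch only (not part of this patch): pairing of
 * get_vma()/put_vma() in a lockless fault path. The function name and
 * the validation details are assumptions for the example.
 */
static vm_fault_t spf_try_fault(struct mm_struct *mm, unsigned long addr)
{
	struct vm_area_struct *vma;
	unsigned int seq;
	vm_fault_t ret = VM_FAULT_RETRY;

	/* Take a reference on the VMA under mm->mm_rb_lock (read side). */
	vma = get_vma(mm, addr);
	if (!vma)
		return VM_FAULT_RETRY;	/* fall back to the classic path */

	/* Snapshot vm_sequence; an odd value means a writer is active. */
	seq = raw_read_seqcount(&vma->vm_sequence);
	if (seq & 1)
		goto out_put;

	if (addr < vma->vm_start || addr >= vma->vm_end)
		goto out_put;

	/*
	 * ... walk the page tables here, then re-check vma->vm_sequence
	 * against 'seq' before committing the new PTE ...
	 */

out_put:
	/* Drop the reference; frees the VMA if it was unlinked meanwhile. */
	put_vma(vma);
	return ret;
}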