diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 90e51fa3696b..8cc9faac4b9c 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -358,18 +358,26 @@ static inline long userfaultfd_get_blocking_state(unsigned int flags)
 }
 
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-bool userfaultfd_using_sigbus(struct vm_area_struct *vma)
+bool userfaultfd_using_sigbus(struct vm_fault *vmf)
 {
-	struct userfaultfd_ctx *ctx;
-	bool ret;
+	bool ret = false;
 
 	/*
 	 * Do it inside RCU section to ensure that the ctx doesn't
 	 * disappear under us.
 	 */
 	rcu_read_lock();
-	ctx = rcu_dereference(vma->vm_userfaultfd_ctx.ctx);
-	ret = ctx && (ctx->features & UFFD_FEATURE_SIGBUS);
+	/*
+	 * Ensure that we are not looking at dangling pointer to
+	 * userfaultfd_ctx, which could happen if userfaultfd_release() is
+	 * called and vma is unlinked.
+	 */
+	if (!vma_has_changed(vmf)) {
+		struct userfaultfd_ctx *ctx;
+
+		ctx = rcu_dereference(vmf->vma->vm_userfaultfd_ctx.ctx);
+		ret = ctx && (ctx->features & UFFD_FEATURE_SIGBUS);
+	}
 	rcu_read_unlock();
 	return ret;
 }
diff --git a/include/linux/mm.h b/include/linux/mm.h
index dfefcfa1d6a4..714878c6814f 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1776,6 +1776,20 @@ static inline void vm_write_end(struct vm_area_struct *vma)
 {
 	raw_write_seqcount_end(&vma->vm_sequence);
 }
+
+static inline bool vma_has_changed(struct vm_fault *vmf)
+{
+	int ret = RB_EMPTY_NODE(&vmf->vma->vm_rb);
+	unsigned int seq = READ_ONCE(vmf->vma->vm_sequence.sequence);
+
+	/*
+	 * Matches both the wmb in write_seqlock_{begin,end}() and
+	 * the wmb in vma_rb_erase().
+	 */
+	smp_rmb();
+
+	return ret || seq != vmf->sequence;
+}
 #else
 static inline void vm_write_begin(struct vm_area_struct *vma)
 {
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index c8d776bee7e7..0c61fd3916dd 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -40,7 +40,7 @@ extern int sysctl_unprivileged_userfaultfd;
 extern vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason);
 
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-extern bool userfaultfd_using_sigbus(struct vm_area_struct *vma);
+extern bool userfaultfd_using_sigbus(struct vm_fault *vmf);
 #endif
 
 /*
diff --git a/mm/internal.h b/mm/internal.h
index 5a3f3725d306..a06e45c82aed 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -40,20 +40,6 @@ vm_fault_t do_swap_page(struct vm_fault *vmf);
 extern struct vm_area_struct *get_vma(struct mm_struct *mm,
 				      unsigned long addr);
 extern void put_vma(struct vm_area_struct *vma);
-
-static inline bool vma_has_changed(struct vm_fault *vmf)
-{
-	int ret = RB_EMPTY_NODE(&vmf->vma->vm_rb);
-	unsigned int seq = READ_ONCE(vmf->vma->vm_sequence.sequence);
-
-	/*
-	 * Matches both the wmb in write_seqlock_{begin,end}() and
-	 * the wmb in vma_rb_erase().
-	 */
-	smp_rmb();
-
-	return ret || seq != vmf->sequence;
-}
 #endif /* CONFIG_SPECULATIVE_PAGE_FAULT */
 
 void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
diff --git a/mm/memory.c b/mm/memory.c
index 3bcfce7c6b96..3d0317538e25 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -5058,6 +5058,7 @@ static vm_fault_t ___handle_speculative_fault(struct mm_struct *mm,
 
 	vmf.vma_flags = READ_ONCE(vmf.vma->vm_flags);
 	vmf.vma_page_prot = READ_ONCE(vmf.vma->vm_page_prot);
+	vmf.sequence = seq;
 
 #ifdef CONFIG_USERFAULTFD
 	/*
@@ -5067,7 +5068,7 @@ static vm_fault_t ___handle_speculative_fault(struct mm_struct *mm,
 	if (unlikely(vmf.vma_flags & __VM_UFFD_FLAGS)) {
 		uffd_missing_sigbus = vma_is_anonymous(vmf.vma) &&
					(vmf.vma_flags & VM_UFFD_MISSING) &&
-					userfaultfd_using_sigbus(vmf.vma);
+					userfaultfd_using_sigbus(&vmf);
 		if (!uffd_missing_sigbus) {
 			trace_spf_vma_notsup(_RET_IP_, vmf.vma, address);
 			return VM_FAULT_RETRY;
 		}
@@ -5193,7 +5194,6 @@ static vm_fault_t ___handle_speculative_fault(struct mm_struct *mm,
 		vmf.pte = NULL;
 	}
 
-	vmf.sequence = seq;
 	vmf.flags = flags;
 
 	local_irq_enable();
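Not part of the patch: the hunks above move vma_has_changed() into include/linux/mm.h and make userfaultfd_using_sigbus() revalidate the speculative vmf->sequence snapshot before dereferencing vm_userfaultfd_ctx, so a vma being torn down by userfaultfd_release() is never trusted. The stand-alone, user-space C sketch below illustrates only the sequence-count revalidation pattern that check relies on; all names (struct fake_vma, fake_vm_write_begin(), fake_vma_has_changed(), ...) are invented for illustration, and C11 atomics stand in for the kernel's vm_sequence seqcount and smp_rmb()/wmb pairing.

/*
 * Illustration only, not kernel code: a minimal user-space analogue of the
 * snapshot-then-revalidate scheme used by the speculative fault path.
 */
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

struct fake_vma {
	_Atomic unsigned int sequence;	/* even: stable, odd: write in progress */
	unsigned long vm_start;
	unsigned long vm_end;
};

/* Writer side: bump the count before and after modifying the vma. */
static void fake_vm_write_begin(struct fake_vma *vma)
{
	atomic_fetch_add_explicit(&vma->sequence, 1, memory_order_release);
}

static void fake_vm_write_end(struct fake_vma *vma)
{
	atomic_fetch_add_explicit(&vma->sequence, 1, memory_order_release);
}

/* Speculative reader: snapshot the count before using the vma ... */
static unsigned int fake_vma_snapshot(struct fake_vma *vma)
{
	return atomic_load_explicit(&vma->sequence, memory_order_acquire);
}

/* ... and revalidate the snapshot before trusting anything read from it. */
static bool fake_vma_has_changed(struct fake_vma *vma, unsigned int snap)
{
	unsigned int seq = atomic_load_explicit(&vma->sequence,
						memory_order_acquire);

	/*
	 * An odd snapshot means it raced with a writer; a moved count means
	 * the vma was modified after the snapshot was taken.
	 */
	return (snap & 1) || seq != snap;
}

int main(void)
{
	struct fake_vma vma = { .vm_start = 0x1000, .vm_end = 0x2000 };
	unsigned int snap = fake_vma_snapshot(&vma);

	/* No writer ran: the speculative read is still valid (prints 0). */
	printf("changed before write: %d\n", fake_vma_has_changed(&vma, snap));

	fake_vm_write_begin(&vma);
	vma.vm_end = 0x3000;		/* concurrent modification */
	fake_vm_write_end(&vma);

	/* The count moved, so the speculation must be thrown away (prints 1). */
	printf("changed after write:  %d\n", fake_vma_has_changed(&vma, snap));
	return 0;
}

The shape mirrors the patch: the speculative reader never blocks, it only rechecks the count and discards its work if the count moved, which is why ___handle_speculative_fault() can simply return VM_FAULT_RETRY and let the regular page-fault path, taken under the mmap lock, redo the fault.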