diff --git a/arch/arm64/include/asm/fpsimd.h b/arch/arm64/include/asm/fpsimd.h index 6f86b7ab6c28..d720b6f7e5f9 100644 --- a/arch/arm64/include/asm/fpsimd.h +++ b/arch/arm64/include/asm/fpsimd.h @@ -339,7 +339,7 @@ static inline int sme_max_virtualisable_vl(void) return vec_max_virtualisable_vl(ARM64_VEC_SME); } -extern void sme_alloc(struct task_struct *task); +extern void sme_alloc(struct task_struct *task, bool flush); extern unsigned int sme_get_vl(void); extern int sme_set_current_vl(unsigned long arg); extern int sme_get_current_vl(void); @@ -365,7 +365,7 @@ static inline void sme_smstart_sm(void) { } static inline void sme_smstop_sm(void) { } static inline void sme_smstop(void) { } -static inline void sme_alloc(struct task_struct *task) { } +static inline void sme_alloc(struct task_struct *task, bool flush) { } static inline void sme_setup(void) { } static inline unsigned int sme_get_vl(void) { return 0; } static inline int sme_max_vl(void) { return 0; } diff --git a/arch/arm64/kernel/fpsimd.c b/arch/arm64/kernel/fpsimd.c index 356036babd09..8cd59d387b90 100644 --- a/arch/arm64/kernel/fpsimd.c +++ b/arch/arm64/kernel/fpsimd.c @@ -1239,9 +1239,9 @@ void fpsimd_release_task(struct task_struct *dead_task) * the interest of testability and predictability, the architecture * guarantees that when ZA is enabled it will be zeroed. */ -void sme_alloc(struct task_struct *task) +void sme_alloc(struct task_struct *task, bool flush) { - if (task->thread.za_state) { + if (task->thread.za_state && flush) { memset(task->thread.za_state, 0, za_state_size(task)); return; } @@ -1460,7 +1460,7 @@ void do_sme_acc(unsigned long esr, struct pt_regs *regs) } sve_alloc(current, false); - sme_alloc(current); + sme_alloc(current, true); if (!current->thread.sve_state || !current->thread.za_state) { force_sig(SIGKILL); return; diff --git a/arch/arm64/kernel/ptrace.c b/arch/arm64/kernel/ptrace.c index f19f020ccff9..f606c942f514 100644 --- a/arch/arm64/kernel/ptrace.c +++ b/arch/arm64/kernel/ptrace.c @@ -886,6 +886,13 @@ static int sve_set_common(struct task_struct *target, break; case ARM64_VEC_SME: target->thread.svcr |= SVCR_SM_MASK; + + /* + * Disable traps and ensure there is SME storage but + * preserve any currently set values in ZA/ZT. + */ + sme_alloc(target, false); + set_tsk_thread_flag(target, TIF_SME); break; default: WARN_ON_ONCE(1); @@ -1107,7 +1114,7 @@ static int za_set(struct task_struct *target, } /* Allocate/reinit ZA storage */ - sme_alloc(target); + sme_alloc(target, true); if (!target->thread.za_state) { ret = -ENOMEM; goto out; diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c index 43adbfa5ead7..82f4572c8ddf 100644 --- a/arch/arm64/kernel/signal.c +++ b/arch/arm64/kernel/signal.c @@ -430,7 +430,7 @@ static int restore_za_context(struct user_ctxs *user) fpsimd_flush_task_state(current); /* From now, fpsimd_thread_switch() won't touch thread.sve_state */ - sme_alloc(current); + sme_alloc(current, true); if (!current->thread.za_state) { current->thread.svcr &= ~SVCR_ZA_MASK; clear_thread_flag(TIF_SME);