From: Jann Horn <jannh@google.com>
commit 16d51a590a8ce3befb1308e0e7ab77f3b661af33 upstream.
When going through execve(), zero out the NUMA fault statistics instead of freeing them.
During execve, the task is reachable through procfs and the scheduler. A concurrent /proc/*/sched reader can read data from a freed ->numa_faults allocation (confirmed by KASAN) and write it back to userspace. I believe that it would also be possible for a use-after-free read to occur through a race between a NUMA fault and execve(): task_numa_fault() can lead to task_numa_compare(), which invokes task_weight() on the currently running task of a different CPU.
Another way to fix this would be to make ->numa_faults RCU-managed or add extra locking, but it seems easier to wipe the NUMA fault statistics on execve.
Signed-off-by: Jann Horn <jannh@google.com>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Petr Mladek <pmladek@suse.com>
Cc: Sergey Senozhatsky <sergey.senozhatsky@gmail.com>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Will Deacon <will@kernel.org>
Fixes: 82727018b0d3 ("sched/numa: Call task_numa_free() from do_execve()")
Link: https://lkml.kernel.org/r/20190716152047.14424-1-jannh@google.com
Signed-off-by: Ingo Molnar <mingo@kernel.org>
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
---
 fs/exec.c                            |  2 +-
 include/linux/sched/numa_balancing.h |  4 ++--
 kernel/fork.c                        |  2 +-
 kernel/sched/fair.c                  | 24 ++++++++++++++++++++----
 4 files changed, 24 insertions(+), 8 deletions(-)
--- a/fs/exec.c +++ b/fs/exec.c @@ -1828,7 +1828,7 @@ static int __do_execve_file(int fd, stru membarrier_execve(current); rseq_execve(current); acct_update_integrals(current); - task_numa_free(current); + task_numa_free(current, false); free_bprm(bprm); kfree(pathbuf); if (filename) --- a/include/linux/sched/numa_balancing.h +++ b/include/linux/sched/numa_balancing.h @@ -19,7 +19,7 @@ extern void task_numa_fault(int last_node, int node, int pages, int flags); extern pid_t task_numa_group_id(struct task_struct *p); extern void set_numabalancing_state(bool enabled); -extern void task_numa_free(struct task_struct *p); +extern void task_numa_free(struct task_struct *p, bool final); extern bool should_numa_migrate_memory(struct task_struct *p, struct page *page, int src_nid, int dst_cpu); #else @@ -34,7 +34,7 @@ static inline pid_t task_numa_group_id(s static inline void set_numabalancing_state(bool enabled) { } -static inline void task_numa_free(struct task_struct *p) +static inline void task_numa_free(struct task_struct *p, bool final) { } static inline bool should_numa_migrate_memory(struct task_struct *p, --- a/kernel/fork.c +++ b/kernel/fork.c @@ -727,7 +727,7 @@ void __put_task_struct(struct task_struc WARN_ON(tsk == current);
cgroup_free(tsk); - task_numa_free(tsk); + task_numa_free(tsk, true); security_task_free(tsk); exit_creds(tsk); delayacct_tsk_free(tsk); --- a/kernel/sched/fair.c +++ b/kernel/sched/fair.c @@ -2336,13 +2336,23 @@ no_join: return; }
-void task_numa_free(struct task_struct *p) +/* + * Get rid of NUMA staticstics associated with a task (either current or dead). + * If @final is set, the task is dead and has reached refcount zero, so we can + * safely free all relevant data structures. Otherwise, there might be + * concurrent reads from places like load balancing and procfs, and we should + * reset the data back to default state without freeing ->numa_faults. + */ +void task_numa_free(struct task_struct *p, bool final) { struct numa_group *grp = p->numa_group; - void *numa_faults = p->numa_faults; + unsigned long *numa_faults = p->numa_faults; unsigned long flags; int i;
+ if (!numa_faults) + return; + if (grp) { spin_lock_irqsave(&grp->lock, flags); for (i = 0; i < NR_NUMA_HINT_FAULT_STATS * nr_node_ids; i++) @@ -2355,8 +2365,14 @@ void task_numa_free(struct task_struct * put_numa_group(grp); }
- p->numa_faults = NULL; - kfree(numa_faults); + if (final) { + p->numa_faults = NULL; + kfree(numa_faults); + } else { + p->total_numa_faults = 0; + for (i = 0; i < NR_NUMA_HINT_FAULT_STATS * nr_node_ids; i++) + numa_faults[i] = 0; + } }
/*