@@ -1668,7 +1668,8 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
 		WRITE_ONCE(cqe->user_data, req->user_data);
 		WRITE_ONCE(cqe->res, res);
 		WRITE_ONCE(cqe->flags, cflags);
-	} else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+	} else if (ctx->cq_overflow_flushed ||
+		   atomic_read(&req->task->io_uring->in_idle)) {
 		/*
 		 * If we're in ring overflow flush mode, or in task cancel mode,
 		 * then we cannot store the request for later flushing, we need
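The completion path switches to atomic_read() because this patch turns in_idle from a plain flag into a counter: both __io_uring_files_cancel() and the SQPOLL branch of io_uring_cancel_task_requests() increment it, and a boolean would be cleared by whichever canceler finished first. A minimal, runnable userspace sketch of the counter-as-flag idiom, using C11 atomics as a stand-in for the kernel's atomic_t (all names here are illustrative, not from the patch):

/* Userspace analog of the in_idle counter-as-flag idiom. */
#include <stdatomic.h>
#include <stdio.h>

static atomic_int in_idle;

static void enter_idle(void) { atomic_fetch_add(&in_idle, 1); }
static void exit_idle(void)  { atomic_fetch_sub(&in_idle, 1); }

/* True while *any* cancelation path has marked the task idle. */
static int is_idle(void) { return atomic_load(&in_idle) > 0; }

int main(void)
{
	enter_idle();	/* e.g. __io_uring_files_cancel() */
	enter_idle();	/* e.g. SQPOLL cancelation, nested */
	exit_idle();	/* first path done... */
	printf("still idle: %d\n", is_idle());	/* 1: flag survives */
	exit_idle();
	printf("still idle: %d\n", is_idle());	/* 0 */
	return 0;
}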
@@ -1838,7 +1839,7 @@ static void __io_free_req(struct io_kiocb *req)
 	io_dismantle_req(req);
 
 	percpu_counter_dec(&tctx->inflight);
-	if (tctx->in_idle)
+	if (atomic_read(&tctx->in_idle))
 		wake_up(&tctx->wait);
 	put_task_struct(req->task);
 
@@ -7695,7 +7696,8 @@ static int io_uring_alloc_task_context(struct task_struct *task)
 	xa_init(&tctx->xa);
 	init_waitqueue_head(&tctx->wait);
 	tctx->last = NULL;
-	tctx->in_idle = 0;
+	atomic_set(&tctx->in_idle, 0);
+	tctx->sqpoll = false;
 	io_init_identity(&tctx->__identity);
 	tctx->identity = &tctx->__identity;
 	task->io_uring = tctx;
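For orientation, here is a partial sketch of struct io_uring_task, inferred only from the fields this function initializes and the ones this patch touches; it is not the authoritative definition:

/*
 * Partial sketch of struct io_uring_task, inferred from this patch;
 * the real definition may have additional members.
 */
struct io_uring_task {
	struct xarray		xa;		/* ring files this task has used */
	struct wait_queue_head	wait;		/* woken as inflight drains */
	struct file		*last;		/* most recently used ring file */
	atomic_t		in_idle;	/* counter; was a plain flag */
	bool			sqpoll;		/* new: task has used SQPOLL rings */
	struct percpu_counter	inflight;	/* requests in flight */
	struct io_identity	__identity;
	struct io_identity	*identity;
};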
@@ -8598,21 +8600,35 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 {
 	struct task_struct *task = current;
 
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data)
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
 		task = ctx->sq_data->thread;
+		atomic_inc(&task->io_uring->in_idle);
+		io_sq_thread_park(ctx->sq_data);
+	}
 
 	io_cqring_overflow_flush(ctx, true, task, files);
 
 	while (__io_uring_cancel_task_requests(ctx, task, files)) {
 		io_run_task_work();
 		cond_resched();
 	}
+
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+		atomic_dec(&task->io_uring->in_idle);
+		/*
+		 * If the files that are going away are the ones in the thread
+		 * identity, clear them out.
+		 */
+		if (task->io_uring->identity->files == files)
+			task->io_uring->identity->files = NULL;
+		io_sq_thread_unpark(ctx->sq_data);
+	}
 }
 
 /*
  * Note that this task has used io_uring. We use it for cancelation purposes.
  */
-static int io_uring_add_task_file(struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
 {
 	struct io_uring_task *tctx = current->io_uring;
 
@@ -8634,6 +8650,14 @@ static int io_uring_add_task_file(struct file *file)
 		tctx->last = file;
 	}
 
+	/*
+	 * This is race safe in that the task itself is doing this, hence it
+	 * cannot be going through the exit/cancel paths at the same time.
+	 * This cannot be modified while exit/cancel is running.
+	 */
+	if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
+		tctx->sqpoll = true;
+
 	return 0;
 }
 
@@ -8675,7 +8699,7 @@ void __io_uring_files_cancel(struct files_struct *files)
 	unsigned long index;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	xa_for_each(&tctx->xa, index, file) {
 		struct io_ring_ctx *ctx = file->private_data;
@@ -8684,6 +8708,35 @@ void __io_uring_files_cancel(struct files_struct *files)
 		if (files)
 			io_uring_del_task_file(file);
 	}
+
+	atomic_dec(&tctx->in_idle);
+}
+
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+	unsigned long index;
+	struct file *file;
+	s64 inflight;
+
+	inflight = percpu_counter_sum(&tctx->inflight);
+	if (!tctx->sqpoll)
+		return inflight;
+
+	/*
+	 * If we have SQPOLL rings, then we need to iterate and find them, and
+	 * add the pending count for those.
+	 */
+	xa_for_each(&tctx->xa, index, file) {
+		struct io_ring_ctx *ctx = file->private_data;
+
+		if (ctx->flags & IORING_SETUP_SQPOLL) {
+			struct io_uring_task *__tctx = ctx->sqo_task->io_uring;
+
+			inflight += percpu_counter_sum(&__tctx->inflight);
+		}
+	}
+
+	return inflight;
+}
 
 /*
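tctx_inflight() exists because SQPOLL submissions are accounted against the ring owner's context (ctx->sqo_task->io_uring) rather than the submitting task's, so summing only the local percpu counter could let the cancel loop below exit while SQPOLL-driven requests are still in flight. A runnable userspace analog of the aggregation, with plain longs standing in for percpu_counter and hypothetical names throughout:

/* Userspace analog of tctx_inflight(); illustrative names only. */
#include <stdbool.h>
#include <stdio.h>

struct ring {
	bool sqpoll;		/* IORING_SETUP_SQPOLL analog */
	long owner_inflight;	/* sqo_task->io_uring->inflight analog */
};

struct task_ctx {
	long inflight;		/* this task's own counter */
	bool sqpoll;		/* has this task used any SQPOLL ring? */
	struct ring *rings;	/* stand-in for walking tctx->xa */
	int nr_rings;
};

static long tctx_inflight(struct task_ctx *tctx)
{
	long inflight = tctx->inflight;
	int i;

	if (!tctx->sqpoll)
		return inflight;
	/* Fold in the counts charged to SQPOLL ring owners. */
	for (i = 0; i < tctx->nr_rings; i++)
		if (tctx->rings[i].sqpoll)
			inflight += tctx->rings[i].owner_inflight;
	return inflight;
}

int main(void)
{
	struct ring rings[] = { { true, 3 }, { false, 7 } };
	struct task_ctx tctx = { .inflight = 2, .sqpoll = true,
				 .rings = rings, .nr_rings = 2 };
	printf("inflight = %ld\n", tctx_inflight(&tctx));	/* 5 */
	return 0;
}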
@@ -8697,11 +8750,11 @@ void __io_uring_task_cancel(void)
 	s64 inflight;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	do {
 		/* read completions before cancelations */
-		inflight = percpu_counter_sum(&tctx->inflight);
+		inflight = tctx_inflight(tctx);
 		if (!inflight)
 			break;
 		__io_uring_files_cancel(NULL);
@@ -8712,13 +8765,13 @@ void __io_uring_task_cancel(void)
 		 * If we've seen completions, retry. This avoids a race where
 		 * a completion comes in before we did prepare_to_wait().
 		 */
-		if (inflight != percpu_counter_sum(&tctx->inflight))
+		if (inflight != tctx_inflight(tctx))
 			continue;
 		schedule();
 	} while (1);
 
 	finish_wait(&tctx->wait, &wait);
-	tctx->in_idle = false;
+	atomic_dec(&tctx->in_idle);
 }
 
 static int io_uring_flush(struct file *file, void *data)
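One detail worth noting in the loop above: the task is queued on tctx->wait before inflight is re-read, so a wake_up() from __io_free_req() cannot slip into the gap between the check and the sleep. A condensed restatement of the pattern, for readability only; the prepare_to_wait() call and its TASK_UNINTERRUPTIBLE state sit outside the hunks shown and are assumed here:

/* Condensed sketch of the cancel/wait loop; not part of the patch. */
do {
	inflight = tctx_inflight(tctx);		/* completions first... */
	if (!inflight)
		break;
	__io_uring_files_cancel(NULL);		/* ...then cancelations */

	prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
	if (inflight != tctx_inflight(tctx))
		continue;	/* a completion raced in; re-check, don't sleep */
	schedule();		/* woken by wake_up(&tctx->wait) on request free */
} while (1);
finish_wait(&tctx->wait, &wait);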
@@ -8863,7 +8916,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 		io_sqpoll_wait_sq(ctx);
 		submitted = to_submit;
 	} else if (to_submit) {
-		ret = io_uring_add_task_file(f.file);
+		ret = io_uring_add_task_file(ctx, f.file);
 		if (unlikely(ret))
 			goto out;
 		mutex_lock(&ctx->uring_lock);
@@ -9092,7 +9145,7 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
 #if defined(CONFIG_UNIX)
 	ctx->ring_sock->file = file;
 #endif
-	if (unlikely(io_uring_add_task_file(file))) {
+	if (unlikely(io_uring_add_task_file(ctx, file))) {
 		file = ERR_PTR(-ENOMEM);
 		goto err_fd;
 	}