@@ -995,20 +995,33 @@ static void io_sq_thread_drop_mm(void)
995995 if (mm ) {
996996 kthread_unuse_mm (mm );
997997 mmput (mm );
998+ current -> mm = NULL ;
998999 }
9991000}
10001001
10011002static int __io_sq_thread_acquire_mm (struct io_ring_ctx * ctx )
10021003{
1003- if (!current -> mm ) {
1004- if (unlikely (!(ctx -> flags & IORING_SETUP_SQPOLL ) ||
1005- !ctx -> sqo_task -> mm ||
1006- !mmget_not_zero (ctx -> sqo_task -> mm )))
1007- return - EFAULT ;
1008- kthread_use_mm (ctx -> sqo_task -> mm );
1004+ struct mm_struct * mm ;
1005+
1006+ if (current -> mm )
1007+ return 0 ;
1008+
1009+ /* Should never happen */
1010+ if (unlikely (!(ctx -> flags & IORING_SETUP_SQPOLL )))
1011+ return - EFAULT ;
1012+
1013+ task_lock (ctx -> sqo_task );
1014+ mm = ctx -> sqo_task -> mm ;
1015+ if (unlikely (!mm || !mmget_not_zero (mm )))
1016+ mm = NULL ;
1017+ task_unlock (ctx -> sqo_task );
1018+
1019+ if (mm ) {
1020+ kthread_use_mm (mm );
1021+ return 0 ;
10091022 }
10101023
1011- return 0 ;
1024+ return - EFAULT ;
10121025}
10131026
10141027static int io_sq_thread_acquire_mm (struct io_ring_ctx * ctx ,
@@ -1274,9 +1287,12 @@ static bool io_identity_cow(struct io_kiocb *req)
12741287 /* add one for this request */
12751288 refcount_inc (& id -> count );
12761289
1277- /* drop old identity, assign new one. one ref for req, one for tctx */
1278- if (req -> work .identity != tctx -> identity &&
1279- refcount_sub_and_test (2 , & req -> work .identity -> count ))
1290+ /* drop tctx and req identity references, if needed */
1291+ if (tctx -> identity != & tctx -> __identity &&
1292+ refcount_dec_and_test (& tctx -> identity -> count ))
1293+ kfree (tctx -> identity );
1294+ if (req -> work .identity != & tctx -> __identity &&
1295+ refcount_dec_and_test (& req -> work .identity -> count ))
12801296 kfree (req -> work .identity );
12811297
12821298 req -> work .identity = id ;
@@ -1577,14 +1593,29 @@ static void io_cqring_mark_overflow(struct io_ring_ctx *ctx)
15771593 }
15781594}
15791595
1580- static inline bool io_match_files (struct io_kiocb * req ,
1581- struct files_struct * files )
1596+ static inline bool __io_match_files (struct io_kiocb * req ,
1597+ struct files_struct * files )
15821598{
1599+ return ((req -> flags & REQ_F_WORK_INITIALIZED ) &&
1600+ (req -> work .flags & IO_WQ_WORK_FILES )) &&
1601+ req -> work .identity -> files == files ;
1602+ }
1603+
1604+ static bool io_match_files (struct io_kiocb * req ,
1605+ struct files_struct * files )
1606+ {
1607+ struct io_kiocb * link ;
1608+
15831609 if (!files )
15841610 return true;
1585- if ((req -> flags & REQ_F_WORK_INITIALIZED ) &&
1586- (req -> work .flags & IO_WQ_WORK_FILES ))
1587- return req -> work .identity -> files == files ;
1611+ if (__io_match_files (req , files ))
1612+ return true;
1613+ if (req -> flags & REQ_F_LINK_HEAD ) {
1614+ list_for_each_entry (link , & req -> link_list , link_list ) {
1615+ if (__io_match_files (link , files ))
1616+ return true;
1617+ }
1618+ }
15881619 return false;
15891620}
15901621
@@ -1668,7 +1699,8 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
16681699 WRITE_ONCE (cqe -> user_data , req -> user_data );
16691700 WRITE_ONCE (cqe -> res , res );
16701701 WRITE_ONCE (cqe -> flags , cflags );
1671- } else if (ctx -> cq_overflow_flushed || req -> task -> io_uring -> in_idle ) {
1702+ } else if (ctx -> cq_overflow_flushed ||
1703+ atomic_read (& req -> task -> io_uring -> in_idle )) {
16721704 /*
16731705 * If we're in ring overflow flush mode, or in task cancel mode,
16741706 * then we cannot store the request for later flushing, we need
@@ -1838,7 +1870,7 @@ static void __io_free_req(struct io_kiocb *req)
18381870 io_dismantle_req (req );
18391871
18401872 percpu_counter_dec (& tctx -> inflight );
1841- if (tctx -> in_idle )
1873+ if (atomic_read ( & tctx -> in_idle ) )
18421874 wake_up (& tctx -> wait );
18431875 put_task_struct (req -> task );
18441876
@@ -7695,7 +7727,8 @@ static int io_uring_alloc_task_context(struct task_struct *task)
76957727 xa_init (& tctx -> xa );
76967728 init_waitqueue_head (& tctx -> wait );
76977729 tctx -> last = NULL ;
7698- tctx -> in_idle = 0 ;
7730+ atomic_set (& tctx -> in_idle , 0 );
7731+ tctx -> sqpoll = false;
76997732 io_init_identity (& tctx -> __identity );
77007733 tctx -> identity = & tctx -> __identity ;
77017734 task -> io_uring = tctx ;
@@ -8388,22 +8421,6 @@ static bool io_match_link(struct io_kiocb *preq, struct io_kiocb *req)
83888421 return false;
83898422}
83908423
8391- static bool io_match_link_files (struct io_kiocb * req ,
8392- struct files_struct * files )
8393- {
8394- struct io_kiocb * link ;
8395-
8396- if (io_match_files (req , files ))
8397- return true;
8398- if (req -> flags & REQ_F_LINK_HEAD ) {
8399- list_for_each_entry (link , & req -> link_list , link_list ) {
8400- if (io_match_files (link , files ))
8401- return true;
8402- }
8403- }
8404- return false;
8405- }
8406-
84078424/*
84088425 * We're looking to cancel 'req' because it's holding on to our files, but
84098426 * 'req' could be a link to another request. See if it is, and cancel that
@@ -8453,7 +8470,21 @@ static bool io_timeout_remove_link(struct io_ring_ctx *ctx,
84538470
84548471static bool io_cancel_link_cb (struct io_wq_work * work , void * data )
84558472{
8456- return io_match_link (container_of (work , struct io_kiocb , work ), data );
8473+ struct io_kiocb * req = container_of (work , struct io_kiocb , work );
8474+ bool ret ;
8475+
8476+ if (req -> flags & REQ_F_LINK_TIMEOUT ) {
8477+ unsigned long flags ;
8478+ struct io_ring_ctx * ctx = req -> ctx ;
8479+
8480+ /* protect against races with linked timeouts */
8481+ spin_lock_irqsave (& ctx -> completion_lock , flags );
8482+ ret = io_match_link (req , data );
8483+ spin_unlock_irqrestore (& ctx -> completion_lock , flags );
8484+ } else {
8485+ ret = io_match_link (req , data );
8486+ }
8487+ return ret ;
84578488}
84588489
84598490static void io_attempt_cancel (struct io_ring_ctx * ctx , struct io_kiocb * req )
@@ -8479,14 +8510,16 @@ static void io_attempt_cancel(struct io_ring_ctx *ctx, struct io_kiocb *req)
84798510}
84808511
84818512static void io_cancel_defer_files (struct io_ring_ctx * ctx ,
8513+ struct task_struct * task ,
84828514 struct files_struct * files )
84838515{
84848516 struct io_defer_entry * de = NULL ;
84858517 LIST_HEAD (list );
84868518
84878519 spin_lock_irq (& ctx -> completion_lock );
84888520 list_for_each_entry_reverse (de , & ctx -> defer_list , list ) {
8489- if (io_match_link_files (de -> req , files )) {
8521+ if (io_task_match (de -> req , task ) &&
8522+ io_match_files (de -> req , files )) {
84908523 list_cut_position (& list , & ctx -> defer_list , & de -> list );
84918524 break ;
84928525 }
@@ -8512,7 +8545,6 @@ static bool io_uring_cancel_files(struct io_ring_ctx *ctx,
85128545 if (list_empty_careful (& ctx -> inflight_list ))
85138546 return false;
85148547
8515- io_cancel_defer_files (ctx , files );
85168548 /* cancel all at once, should be faster than doing it one by one*/
85178549 io_wq_cancel_cb (ctx -> io_wq , io_wq_files_match , files , true);
85188550
@@ -8598,21 +8630,40 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
85988630{
85998631 struct task_struct * task = current ;
86008632
8601- if ((ctx -> flags & IORING_SETUP_SQPOLL ) && ctx -> sq_data )
8633+ if ((ctx -> flags & IORING_SETUP_SQPOLL ) && ctx -> sq_data ) {
86028634 task = ctx -> sq_data -> thread ;
8635+ atomic_inc (& task -> io_uring -> in_idle );
8636+ io_sq_thread_park (ctx -> sq_data );
8637+ }
8638+
8639+ if (files )
8640+ io_cancel_defer_files (ctx , NULL , files );
8641+ else
8642+ io_cancel_defer_files (ctx , task , NULL );
86038643
86048644 io_cqring_overflow_flush (ctx , true, task , files );
86058645
86068646 while (__io_uring_cancel_task_requests (ctx , task , files )) {
86078647 io_run_task_work ();
86088648 cond_resched ();
86098649 }
8650+
8651+ if ((ctx -> flags & IORING_SETUP_SQPOLL ) && ctx -> sq_data ) {
8652+ atomic_dec (& task -> io_uring -> in_idle );
8653+ /*
8654+ * If the files that are going away are the ones in the thread
8655+ * identity, clear them out.
8656+ */
8657+ if (task -> io_uring -> identity -> files == files )
8658+ task -> io_uring -> identity -> files = NULL ;
8659+ io_sq_thread_unpark (ctx -> sq_data );
8660+ }
86108661}
86118662
86128663/*
86138664 * Note that this task has used io_uring. We use it for cancelation purposes.
86148665 */
8615- static int io_uring_add_task_file (struct file * file )
8666+ static int io_uring_add_task_file (struct io_ring_ctx * ctx , struct file * file )
86168667{
86178668 struct io_uring_task * tctx = current -> io_uring ;
86188669
@@ -8634,6 +8685,14 @@ static int io_uring_add_task_file(struct file *file)
86348685 tctx -> last = file ;
86358686 }
86368687
8688+ /*
8689+ * This is race safe in that the task itself is doing this, hence it
8690+ * cannot be going through the exit/cancel paths at the same time.
8691+ * This cannot be modified while exit/cancel is running.
8692+ */
8693+ if (!tctx -> sqpoll && (ctx -> flags & IORING_SETUP_SQPOLL ))
8694+ tctx -> sqpoll = true;
8695+
86378696 return 0 ;
86388697}
86398698
@@ -8675,7 +8734,7 @@ void __io_uring_files_cancel(struct files_struct *files)
86758734 unsigned long index ;
86768735
86778736 /* make sure overflow events are dropped */
8678- tctx -> in_idle = true ;
8737+ atomic_inc ( & tctx -> in_idle ) ;
86798738
86808739 xa_for_each (& tctx -> xa , index , file ) {
86818740 struct io_ring_ctx * ctx = file -> private_data ;
@@ -8684,6 +8743,35 @@ void __io_uring_files_cancel(struct files_struct *files)
86848743 if (files )
86858744 io_uring_del_task_file (file );
86868745 }
8746+
8747+ atomic_dec (& tctx -> in_idle );
8748+ }
8749+
8750+ static s64 tctx_inflight (struct io_uring_task * tctx )
8751+ {
8752+ unsigned long index ;
8753+ struct file * file ;
8754+ s64 inflight ;
8755+
8756+ inflight = percpu_counter_sum (& tctx -> inflight );
8757+ if (!tctx -> sqpoll )
8758+ return inflight ;
8759+
8760+ /*
8761+ * If we have SQPOLL rings, then we need to iterate and find them, and
8762+ * add the pending count for those.
8763+ */
8764+ xa_for_each (& tctx -> xa , index , file ) {
8765+ struct io_ring_ctx * ctx = file -> private_data ;
8766+
8767+ if (ctx -> flags & IORING_SETUP_SQPOLL ) {
8768+ struct io_uring_task * __tctx = ctx -> sqo_task -> io_uring ;
8769+
8770+ inflight += percpu_counter_sum (& __tctx -> inflight );
8771+ }
8772+ }
8773+
8774+ return inflight ;
86878775}
86888776
86898777/*
@@ -8697,11 +8785,11 @@ void __io_uring_task_cancel(void)
86978785 s64 inflight ;
86988786
86998787 /* make sure overflow events are dropped */
8700- tctx -> in_idle = true ;
8788+ atomic_inc ( & tctx -> in_idle ) ;
87018789
87028790 do {
87038791 /* read completions before cancelations */
8704- inflight = percpu_counter_sum ( & tctx -> inflight );
8792+ inflight = tctx_inflight ( tctx );
87058793 if (!inflight )
87068794 break ;
87078795 __io_uring_files_cancel (NULL );
@@ -8712,13 +8800,13 @@ void __io_uring_task_cancel(void)
87128800 * If we've seen completions, retry. This avoids a race where
87138801 * a completion comes in before we did prepare_to_wait().
87148802 */
8715- if (inflight != percpu_counter_sum ( & tctx -> inflight ))
8803+ if (inflight != tctx_inflight ( tctx ))
87168804 continue ;
87178805 schedule ();
87188806 } while (1 );
87198807
87208808 finish_wait (& tctx -> wait , & wait );
8721- tctx -> in_idle = false ;
8809+ atomic_dec ( & tctx -> in_idle ) ;
87228810}
87238811
87248812static int io_uring_flush (struct file * file , void * data )
@@ -8863,7 +8951,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
88638951 io_sqpoll_wait_sq (ctx );
88648952 submitted = to_submit ;
88658953 } else if (to_submit ) {
8866- ret = io_uring_add_task_file (f .file );
8954+ ret = io_uring_add_task_file (ctx , f .file );
88678955 if (unlikely (ret ))
88688956 goto out ;
88698957 mutex_lock (& ctx -> uring_lock );
@@ -8900,7 +8988,8 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
89008988#ifdef CONFIG_PROC_FS
89018989static int io_uring_show_cred (int id , void * p , void * data )
89028990{
8903- const struct cred * cred = p ;
8991+ struct io_identity * iod = p ;
8992+ const struct cred * cred = iod -> creds ;
89048993 struct seq_file * m = data ;
89058994 struct user_namespace * uns = seq_user_ns (m );
89068995 struct group_info * gi ;
@@ -9092,7 +9181,7 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
90929181#if defined(CONFIG_UNIX )
90939182 ctx -> ring_sock -> file = file ;
90949183#endif
9095- if (unlikely (io_uring_add_task_file (file ))) {
9184+ if (unlikely (io_uring_add_task_file (ctx , file ))) {
90969185 file = ERR_PTR (- ENOMEM );
90979186 goto err_fd ;
90989187 }
0 commit comments