@@ -577,6 +577,8 @@ static int vhost_vdpa_map(struct vhost_vdpa *v,
 
 	if (r)
 		vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
+	else
+		atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm);
 
 	return r;
 }
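The first hunk moves the pinned_vm charge into vhost_vdpa_map() itself: pages are accounted against RLIMIT_MEMLOCK only once the IOTLB insertion has actually succeeded, so the matching uncharge can live with the unmap path instead of being patched up by callers. A minimal userspace model of that charge-on-success pattern, assuming nothing beyond the C standard library (the demo_* names and the plain atomic counter are illustrative stand-ins, not kernel APIs):

#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

#define DEMO_PAGE_SHIFT 12			/* 4 KiB pages */

static atomic_uint_fast64_t pinned_vm;		/* stands in for mm->pinned_vm */

/* Hypothetical stand-in for the IOTLB insert; may fail. */
static int demo_iotlb_map(uint64_t iova, uint64_t size)
{
	return size ? 0 : -1;			/* pretend empty ranges fail */
}

static int demo_map(uint64_t iova, uint64_t size)
{
	int r = demo_iotlb_map(iova, size);

	if (!r)		/* charge only after the map succeeded */
		atomic_fetch_add(&pinned_vm, size >> DEMO_PAGE_SHIFT);
	return r;
}

static void demo_unmap(uint64_t size)
{
	/* the matching uncharge belongs to the unmap path */
	atomic_fetch_sub(&pinned_vm, size >> DEMO_PAGE_SHIFT);
}

int main(void)
{
	demo_map(0x1000, 1 << 21);	/* 2 MiB map succeeds: charged */
	demo_map(0x2000, 0);		/* map fails: nothing charged */
	printf("pinned pages: %llu\n",
	       (unsigned long long)atomic_load(&pinned_vm));
	demo_unmap(1 << 21);
	return 0;
}

Compared with the old atomic64_add_return() performed up front, this keeps the counter consistent with what is actually mapped, so the error path never has to guess how much to subtract.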
@@ -608,8 +610,9 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
 	unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
 	unsigned int gup_flags = FOLL_LONGTERM;
 	unsigned long npages, cur_base, map_pfn, last_pfn = 0;
-	unsigned long locked, lock_limit, pinned, i;
+	unsigned long lock_limit, sz2pin, nchunks, i;
 	u64 iova = msg->iova;
+	long pinned;
 	int ret = 0;
 
 	if (msg->iova < v->range.first ||
@@ -620,6 +623,7 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
 				    msg->iova + msg->size - 1))
 		return -EEXIST;
 
+	/* Limit the use of memory for bookkeeping */
 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
 	if (!page_list)
 		return -ENOMEM;
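The new comment makes the sizing explicit: page_list occupies exactly one page, so on a typical 64-bit build with 4 KiB pages, list_size works out to PAGE_SIZE / sizeof(struct page *) = 4096 / 8 = 512 entries. Each pass of the pinning loop below therefore handles at most 512 user pages (2 MiB) of the requested region, however large the overall mapping is.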
@@ -628,63 +632,103 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
 		gup_flags |= FOLL_WRITE;
 
 	npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
-	if (!npages)
-		return -EINVAL;
+	if (!npages) {
+		ret = -EINVAL;
+		goto free;
+	}
 
 	mmap_read_lock(dev->mm);
 
-	locked = atomic64_add_return(npages, &dev->mm->pinned_vm);
 	lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
-
-	if (locked > lock_limit) {
+	if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
 		ret = -ENOMEM;
-		goto out;
+		goto unlock;
 	}
 
 	cur_base = msg->uaddr & PAGE_MASK;
 	iova &= PAGE_MASK;
+	nchunks = 0;
 
 	while (npages) {
-		pinned = min_t(unsigned long, npages, list_size);
-		ret = pin_user_pages(cur_base, pinned,
-				     gup_flags, page_list, NULL);
-		if (ret != pinned)
+		sz2pin = min_t(unsigned long, npages, list_size);
+		pinned = pin_user_pages(cur_base, sz2pin,
+					gup_flags, page_list, NULL);
+		if (sz2pin != pinned) {
+			if (pinned < 0) {
+				ret = pinned;
+			} else {
+				unpin_user_pages(page_list, pinned);
+				ret = -ENOMEM;
+			}
 			goto out;
+		}
+		nchunks++;
 
 		if (!last_pfn)
 			map_pfn = page_to_pfn(page_list[0]);
 
-		for (i = 0; i < ret; i++) {
+		for (i = 0; i < pinned; i++) {
 			unsigned long this_pfn = page_to_pfn(page_list[i]);
 			u64 csize;
 
 			if (last_pfn && (this_pfn != last_pfn + 1)) {
 				/* Pin a contiguous chunk of memory */
 				csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT;
-				if (vhost_vdpa_map(v, iova, csize,
-						   map_pfn << PAGE_SHIFT,
-						   msg->perm))
+				ret = vhost_vdpa_map(v, iova, csize,
+						     map_pfn << PAGE_SHIFT,
+						     msg->perm);
+				if (ret) {
+					/*
+					 * Unpin the pages that are left unmapped
+					 * from this point on in the current
+					 * page_list. The remaining outstanding
+					 * ones which may stride across several
+					 * chunks will be covered in the common
+					 * error path subsequently.
+					 */
+					unpin_user_pages(&page_list[i],
+							 pinned - i);
 					goto out;
+				}
+
 				map_pfn = this_pfn;
 				iova += csize;
+				nchunks = 0;
 			}
 
 			last_pfn = this_pfn;
 		}
 
-		cur_base += ret << PAGE_SHIFT;
-		npages -= ret;
+		cur_base += pinned << PAGE_SHIFT;
+		npages -= pinned;
 	}
 
 	/* Pin the rest chunk */
 	ret = vhost_vdpa_map(v, iova, (last_pfn - map_pfn + 1) << PAGE_SHIFT,
 			     map_pfn << PAGE_SHIFT, msg->perm);
 out:
 	if (ret) {
+		if (nchunks) {
+			unsigned long pfn;
+
+			/*
+			 * Unpin the outstanding pages which are yet to be
+			 * mapped but haven't due to vdpa_map() or
+			 * pin_user_pages() failure.
+			 *
+			 * Mapped pages are accounted in vdpa_map(), hence
+			 * the corresponding unpinning will be handled by
+			 * vdpa_unmap().
+			 */
+			WARN_ON(!last_pfn);
+			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
+				unpin_user_page(pfn_to_page(pfn));
+		}
 		vhost_vdpa_unmap(v, msg->iova, msg->size);
-		atomic64_sub(npages, &dev->mm->pinned_vm);
 	}
+unlock:
 	mmap_read_unlock(dev->mm);
+free:
 	free_page((unsigned long)page_list);
 	return ret;
 }
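Taken together, the reworked loop pins at most one page_list worth of pages per pass, maps each physically contiguous run (which may stride across several chunks), and on failure unpins exactly the pages that were pinned but never handed to the map; pages that did get mapped are charged in vhost_vdpa_map() and released later via the unmap path. A compilable userspace model of that control flow, using synthetic integer "pfns" in place of struct page (all demo_* names are illustrative assumptions, not kernel APIs; make demo_map() fail to exercise the unwind):

#include <errno.h>
#include <stdint.h>
#include <stdio.h>

#define CHUNK	4	/* stands in for PAGE_SIZE / sizeof(struct page *) */
#define NPAGES	10

static long pins;	/* net pin balance the error paths must restore */

/* Synthetic pfn stream with one discontinuity so a run spans chunks. */
static uint64_t demo_pfn(unsigned long idx)
{
	return idx < 6 ? 100 + idx : 200 + idx;
}

static long demo_pin(unsigned long idx, long n, uint64_t *list)
{
	for (long i = 0; i < n; i++)
		list[i] = demo_pfn(idx + i);
	pins += n;
	return n;	/* real pin_user_pages() may return < n or -errno */
}

static void demo_unpin_n(long n)
{
	pins -= n;
}

static int demo_map(uint64_t first, uint64_t last)
{
	printf("map run [%llu..%llu]\n",
	       (unsigned long long)first, (unsigned long long)last);
	return 0;	/* return -ENOMEM here to exercise the unwind */
}

int main(void)
{
	uint64_t list[CHUNK];
	unsigned long idx = 0, npages = NPAGES, nchunks = 0;
	uint64_t map_pfn = 0, last_pfn = 0;
	long pinned;
	int ret = 0;

	while (npages) {
		long sz2pin = npages < CHUNK ? (long)npages : CHUNK;

		pinned = demo_pin(idx, sz2pin, list);
		if (pinned != sz2pin) {
			if (pinned < 0) {
				ret = (int)pinned;	/* nothing was pinned */
			} else {
				demo_unpin_n(pinned);	/* short pin: drop it */
				ret = -ENOMEM;
			}
			goto out;
		}
		nchunks++;	/* one more chunk pinned but not yet mapped */

		if (!last_pfn)
			map_pfn = list[0];

		for (long i = 0; i < pinned; i++) {
			uint64_t this_pfn = list[i];

			if (last_pfn && this_pfn != last_pfn + 1) {
				ret = demo_map(map_pfn, last_pfn);
				if (ret) {
					/* drop the tail of this chunk; the
					 * failed run unwinds at out: */
					demo_unpin_n(pinned - i);
					goto out;
				}
				map_pfn = this_pfn;
				nchunks = 0;	/* run fully handed to the map */
			}
			last_pfn = this_pfn;
		}
		idx += (unsigned long)pinned;
		npages -= (unsigned long)pinned;
	}
	ret = demo_map(map_pfn, last_pfn);	/* map the final run */
out:
	if (ret && nchunks)	/* unpin the outstanding, unmapped run */
		demo_unpin_n((long)(last_pfn - map_pfn + 1));
	printf("ret=%d, pins held for mapped pages=%ld\n", ret, pins);
	return 0;
}

The nchunks counter is the key bookkeeping: it is reset whenever a contiguous run has been fully mapped, so the common error path knows whether map_pfn..last_pfn is still outstanding and needs unpinning, while already-mapped runs are left to the unmap path, which also reverses the pinned_vm charge added in the first hunk.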