extent = next - old_addr;
                if (extent > old_end - old_addr)
                        extent = old_end - old_addr;
+               next = (new_addr + PMD_SIZE) & PMD_MASK;
+               if (extent > next - new_addr)
+                       extent = next - new_addr;
                old_pmd = get_old_pmd(vma->vm_mm, old_addr);
                if (!old_pmd)
                        continue;
 
                if (pte_alloc(new_vma->vm_mm, new_pmd))
                        break;
-               next = (new_addr + PMD_SIZE) & PMD_MASK;
-               if (extent > next - new_addr)
-                       extent = next - new_addr;
                move_ptes(vma, old_pmd, old_addr, old_addr + extent, new_vma,
                          new_pmd, new_addr, need_rmap_locks);
        }