Linux 6.9-rc1
[linux-2.6-microblaze.git] / mm / page_ext.c
index 3dc715d..4548fcc 100644 (file)
@@ -9,6 +9,7 @@
 #include <linux/page_owner.h>
 #include <linux/page_idle.h>
 #include <linux/page_table_check.h>
+#include <linux/rcupdate.h>
 
 /*
  * struct page extension
  * can utilize this callback to initialize the state of it correctly.
  */
 
+#ifdef CONFIG_SPARSEMEM
+#define PAGE_EXT_INVALID       (0x1)   /* tag bit: a section's page_ext pointer is invalid */
+#endif
+
 #if defined(CONFIG_PAGE_IDLE_FLAG) && !defined(CONFIG_64BIT)
 static bool need_page_idle(void)
 {
@@ -66,6 +71,7 @@ static bool need_page_idle(void)
 }
 static struct page_ext_operations page_idle_ops __initdata = {
        .need = need_page_idle,
+       .need_shared_flags = true,      /* invoke_need_callbacks() then reserves sizeof(struct page_ext) */
 };
 #endif
 
@@ -81,10 +87,18 @@ static struct page_ext_operations *page_ext_ops[] __initdata = {
 #endif
 };
 
-unsigned long page_ext_size = sizeof(struct page_ext);
+unsigned long page_ext_size;   /* bytes per entry; accumulated in invoke_need_callbacks() */
 
 static unsigned long total_usage;
 
+bool early_page_ext __meminitdata;     /* set by the "early_page_ext" boot parameter */
+static int __init setup_early_page_ext(char *str)
+{
+       early_page_ext = true;  /* @str is ignored: presence of the parameter is the switch */
+       return 0;
+}
+early_param("early_page_ext", setup_early_page_ext);
+
 static bool __init invoke_need_callbacks(void)
 {
        int i;
@@ -92,7 +106,16 @@ static bool __init invoke_need_callbacks(void)
        bool need = false;
 
        for (i = 0; i < entries; i++) {
-               if (page_ext_ops[i]->need && page_ext_ops[i]->need()) {
+               if (page_ext_ops[i]->need()) {
+                       if (page_ext_ops[i]->need_shared_flags) {
+                               page_ext_size = sizeof(struct page_ext);
+                               break;
+                       }
+               }
+       }
+
+       for (i = 0; i < entries; i++) {
+               if (page_ext_ops[i]->need()) {
                        page_ext_ops[i]->offset = page_ext_size;
                        page_ext_size += page_ext_ops[i]->size;
                        need = true;
@@ -113,32 +136,29 @@ static void __init invoke_init_callbacks(void)
        }
 }
 
-#ifndef CONFIG_SPARSEMEM
-void __init page_ext_init_flatmem_late(void)
-{
-       invoke_init_callbacks();
-}
-#endif
-
 /*
  * Return the entry at position @index of the page_ext array that starts
  * at @base. Entries are page_ext_size bytes apart, so the offset is
  * computed in bytes on the void pointer rather than in struct units.
  */
 static inline struct page_ext *get_entry(void *base, unsigned long index)
 {
        return base + page_ext_size * index;
 }
 
 #ifndef CONFIG_SPARSEMEM
-
+void __init page_ext_init_flatmem_late(void)   /* built only when CONFIG_SPARSEMEM is off */
+{
+       invoke_init_callbacks();
+}
 
 /* Non-SPARSEMEM variant: start the node out with a NULL node_page_ext base. */
 void __meminit pgdat_page_ext_init(struct pglist_data *pgdat)
 {
        pgdat->node_page_ext = NULL;
 }
 
-struct page_ext *lookup_page_ext(const struct page *page)
+static struct page_ext *lookup_page_ext(const struct page *page)
 {
        unsigned long pfn = page_to_pfn(page);
        unsigned long index;
        struct page_ext *base;
 
+       WARN_ON_ONCE(!rcu_read_lock_held());
        base = NODE_DATA(page_to_nid(page))->node_page_ext;
        /*
         * The sanity checks the page allocator does upon freeing a
@@ -206,20 +226,27 @@ fail:
 }
 
 #else /* CONFIG_SPARSEMEM */
+/*
+ * A section's page_ext pointer cannot be used when it is NULL or when
+ * the PAGE_EXT_INVALID bit is set in it.
+ */
+static bool page_ext_invalid(struct page_ext *page_ext)
+{
+       unsigned long val = (unsigned long)page_ext;
+
+       if (!page_ext)
+               return true;
+
+       return (val & PAGE_EXT_INVALID) == PAGE_EXT_INVALID;
+}
 
-struct page_ext *lookup_page_ext(const struct page *page)
+static struct page_ext *lookup_page_ext(const struct page *page)
 {
        unsigned long pfn = page_to_pfn(page);
        struct mem_section *section = __pfn_to_section(pfn);
+       struct page_ext *page_ext = READ_ONCE(section->page_ext);
+
+       WARN_ON_ONCE(!rcu_read_lock_held());
        /*
         * The sanity checks the page allocator does upon freeing a
         * page can reach here before the page_ext arrays are
         * allocated when feeding a range of pages to the allocator
         * for the first time during bootup or memory hotplug.
+        *
+        * The section pointer is sampled once with READ_ONCE() above;
+        * a value that is NULL or tagged PAGE_EXT_INVALID yields NULL.
+        * Callers are expected to hold rcu_read_lock(); the
+        * WARN_ON_ONCE() above fires otherwise.
         */
-       if (!section->page_ext)
+       if (page_ext_invalid(page_ext))
                return NULL;
-       return get_entry(section->page_ext, pfn);
+       return get_entry(page_ext, pfn);
 }
 
 static void *__meminit alloc_page_ext(size_t size, int nid)
@@ -298,9 +325,30 @@ static void __free_page_ext(unsigned long pfn)
        ms = __pfn_to_section(pfn);
        if (!ms || !ms->page_ext)
                return;
-       base = get_entry(ms->page_ext, pfn);
+
+       base = READ_ONCE(ms->page_ext);
+       /*
+        * page_ext here can still be untagged (never marked invalid)
+        * when called from the roll back path of online_page_ext().
+        */
+       if (page_ext_invalid(base))
+               base = (void *)base - PAGE_EXT_INVALID;
+       WRITE_ONCE(ms->page_ext, NULL);
+
+       base = get_entry(base, pfn);
        free_page_ext(base);
-       ms->page_ext = NULL;
+}
+
+/*
+ * Tag the page_ext pointer of the section containing @pfn with
+ * PAGE_EXT_INVALID. Nothing is done when there is no mem_section for
+ * @pfn or when the section has no page_ext pointer set.
+ */
+static void __invalidate_page_ext(unsigned long pfn)
+{
+       struct mem_section *section = __pfn_to_section(pfn);
+
+       if (section && section->page_ext) {
+               void *tagged = (void *)section->page_ext + PAGE_EXT_INVALID;
+
+               WRITE_ONCE(section->page_ext, tagged);
+       }
+}
 
 static int __meminit online_page_ext(unsigned long start_pfn,
@@ -329,24 +377,37 @@ static int __meminit online_page_ext(unsigned long start_pfn,
                return 0;
 
        /* rollback */
+       end = pfn - PAGES_PER_SECTION;
        for (pfn = start; pfn < end; pfn += PAGES_PER_SECTION)
                __free_page_ext(pfn);
 
        return -ENOMEM;
 }
 
-static int __meminit offline_page_ext(unsigned long start_pfn,
-                               unsigned long nr_pages, int nid)
+static void __meminit offline_page_ext(unsigned long start_pfn,
+                               unsigned long nr_pages)
 {
        unsigned long start, end, pfn;
 
        start = SECTION_ALIGN_DOWN(start_pfn);
        end = SECTION_ALIGN_UP(start_pfn + nr_pages);
 
+       /*
+        * Freeing of page_ext is done in 3 steps to avoid
+        * use-after-free of it:
+        * 1) Traverse all the sections and mark their page_ext
+        *    as invalid.
+        * 2) Wait for all the existing users of page_ext who
+        *    started before invalidation to finish.
+        * 3) Free the page_ext.
+        *
+        * The range was widened to whole sections above. Step 2 is
+        * the synchronize_rcu() call below, which pairs with the
+        * rcu_read_lock() taken in page_ext_get().
+        */
+       for (pfn = start; pfn < end; pfn += PAGES_PER_SECTION)
+               __invalidate_page_ext(pfn);
+
+       synchronize_rcu();
+
        for (pfn = start; pfn < end; pfn += PAGES_PER_SECTION)
                __free_page_ext(pfn);
-       return 0;
-
 }
 
 static int __meminit page_ext_callback(struct notifier_block *self,
@@ -362,11 +423,11 @@ static int __meminit page_ext_callback(struct notifier_block *self,
                break;
        case MEM_OFFLINE:
                offline_page_ext(mn->start_pfn,
-                               mn->nr_pages, mn->status_change_nid);
+                               mn->nr_pages);
                break;
        case MEM_CANCEL_ONLINE:
                offline_page_ext(mn->start_pfn,
-                               mn->nr_pages, mn->status_change_nid);
+                               mn->nr_pages);
                break;
        case MEM_GOING_OFFLINE:
                break;
@@ -414,7 +475,7 @@ void __init page_ext_init(void)
                        cond_resched();
                }
        }
-       hotplug_memory_notifier(page_ext_callback, 0);
+       hotplug_memory_notifier(page_ext_callback, DEFAULT_CALLBACK_PRI);
        pr_info("allocated %ld bytes of page_ext\n", total_usage);
        invoke_init_callbacks();
        return;
@@ -428,3 +489,46 @@ void __meminit pgdat_page_ext_init(struct pglist_data *pgdat)
 }
 
 #endif
+
+/**
+ * page_ext_get() - Look up the extended information of a page.
+ * @page: The page whose page_ext is wanted.
+ *
+ * On success the RCU read-side critical section entered here is left
+ * open, so the returned page_ext stays valid until page_ext_put() is
+ * called. When no page_ext is found the critical section is closed
+ * again before returning.
+ *
+ * Return: The page_ext of @page, or NULL if none exists.
+ * Context: Any context.  Caller may not sleep until they have called
+ * page_ext_put().
+ */
+struct page_ext *page_ext_get(struct page *page)
+{
+       struct page_ext *ext;
+
+       rcu_read_lock();
+       ext = lookup_page_ext(page);
+       if (ext)
+               return ext;
+
+       /* Nothing to hand out, so do not leave the read lock held. */
+       rcu_read_unlock();
+       return NULL;
+}
+
+/**
+ * page_ext_put() - Finish working with a page's extended information.
+ * @page_ext: Value returned by page_ext_get(); may be NULL.
+ *
+ * Ends the RCU read-side critical section that page_ext_get() left open
+ * for a non-NULL @page_ext. A NULL @page_ext is a no-op, because
+ * page_ext_get() already dropped the lock in that case. The page
+ * extended information may not be valid after this function is called.
+ *
+ * Return: None.
+ * Context: Any context in which the matching page_ext_get() was called.
+ */
+void page_ext_put(struct page_ext *page_ext)
+{
+       if (likely(page_ext))
+               rcu_read_unlock();
+}