x86/mm/pat: Fix off-by-one bugs in interval tree search
[linux-2.6-microblaze.git] / arch / x86 / mm / pat_interval.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Handle caching attributes in page tables (PAT)
4  *
5  * Authors: Venkatesh Pallipadi <venkatesh.pallipadi@intel.com>
6  *          Suresh B Siddha <suresh.b.siddha@intel.com>
7  *
8  * Interval tree used to store the PAT memory type reservations.
9  */
10
11 #include <linux/seq_file.h>
12 #include <linux/debugfs.h>
13 #include <linux/kernel.h>
14 #include <linux/interval_tree_generic.h>
15 #include <linux/sched.h>
16 #include <linux/gfp.h>
17
18 #include <asm/pgtable.h>
19 #include <asm/pat.h>
20
21 #include "pat_internal.h"
22
23 /*
24  * The memtype tree keeps track of memory type for specific
25  * physical memory areas. Without proper tracking, conflicting memory
26  * types in different mappings can cause CPU cache corruption.
27  *
28  * The tree is an interval tree (augmented rbtree) with tree ordered
29  * on starting address. Tree can contain multiple entries for
30  * different regions which overlap. All the aliases have the same
31  * cache attributes of course.
32  *
33  * memtype_lock protects the rbtree.
34  */
35 static inline u64 memtype_interval_start(struct memtype *memtype)
36 {
37         return memtype->start;
38 }
39
40 static inline u64 memtype_interval_end(struct memtype *memtype)
41 {
42         return memtype->end - 1;
43 }
44 INTERVAL_TREE_DEFINE(struct memtype, rb, u64, subtree_max_end,
45                      memtype_interval_start, memtype_interval_end,
46                      static, memtype_interval)
47
48 static struct rb_root_cached memtype_rbroot = RB_ROOT_CACHED;
49
50 enum {
51         MEMTYPE_EXACT_MATCH     = 0,
52         MEMTYPE_END_MATCH       = 1
53 };
54
55 static struct memtype *memtype_match(u64 start, u64 end, int match_type)
56 {
57         struct memtype *match;
58
59         match = memtype_interval_iter_first(&memtype_rbroot, start, end-1);
60         while (match != NULL && match->start < end) {
61                 if ((match_type == MEMTYPE_EXACT_MATCH) &&
62                     (match->start == start) && (match->end == end))
63                         return match;
64
65                 if ((match_type == MEMTYPE_END_MATCH) &&
66                     (match->start < start) && (match->end == end))
67                         return match;
68
69                 match = memtype_interval_iter_next(match, start, end-1);
70         }
71
72         return NULL; /* Returns NULL if there is no match */
73 }
74
75 static int memtype_check_conflict(u64 start, u64 end,
76                                   enum page_cache_mode reqtype,
77                                   enum page_cache_mode *newtype)
78 {
79         struct memtype *match;
80         enum page_cache_mode found_type = reqtype;
81
82         match = memtype_interval_iter_first(&memtype_rbroot, start, end-1);
83         if (match == NULL)
84                 goto success;
85
86         if (match->type != found_type && newtype == NULL)
87                 goto failure;
88
89         dprintk("Overlap at 0x%Lx-0x%Lx\n", match->start, match->end);
90         found_type = match->type;
91
92         match = memtype_interval_iter_next(match, start, end-1);
93         while (match) {
94                 if (match->type != found_type)
95                         goto failure;
96
97                 match = memtype_interval_iter_next(match, start, end-1);
98         }
99 success:
100         if (newtype)
101                 *newtype = found_type;
102
103         return 0;
104
105 failure:
106         pr_info("x86/PAT: %s:%d conflicting memory types %Lx-%Lx %s<->%s\n",
107                 current->comm, current->pid, start, end,
108                 cattr_name(found_type), cattr_name(match->type));
109         return -EBUSY;
110 }
111
112 int memtype_check_insert(struct memtype *new,
113                          enum page_cache_mode *ret_type)
114 {
115         int err = 0;
116
117         err = memtype_check_conflict(new->start, new->end, new->type, ret_type);
118         if (err)
119                 return err;
120
121         if (ret_type)
122                 new->type = *ret_type;
123
124         memtype_interval_insert(new, &memtype_rbroot);
125         return 0;
126 }
127
128 struct memtype *memtype_erase(u64 start, u64 end)
129 {
130         struct memtype *data;
131
132         /*
133          * Since the memtype_rbroot tree allows overlapping ranges,
134          * memtype_erase() checks with EXACT_MATCH first, i.e. free
135          * a whole node for the munmap case.  If no such entry is found,
136          * it then checks with END_MATCH, i.e. shrink the size of a node
137          * from the end for the mremap case.
138          */
139         data = memtype_match(start, end, MEMTYPE_EXACT_MATCH);
140         if (!data) {
141                 data = memtype_match(start, end, MEMTYPE_END_MATCH);
142                 if (!data)
143                         return ERR_PTR(-EINVAL);
144         }
145
146         if (data->start == start) {
147                 /* munmap: erase this node */
148                 memtype_interval_remove(data, &memtype_rbroot);
149         } else {
150                 /* mremap: update the end value of this node */
151                 memtype_interval_remove(data, &memtype_rbroot);
152                 data->end = start;
153                 memtype_interval_insert(data, &memtype_rbroot);
154                 return NULL;
155         }
156
157         return data;
158 }
159
160 struct memtype *memtype_lookup(u64 addr)
161 {
162         return memtype_interval_iter_first(&memtype_rbroot, addr,
163                                            addr + PAGE_SIZE-1);
164 }
165
166 #if defined(CONFIG_DEBUG_FS)
167 int memtype_copy_nth_element(struct memtype *out, loff_t pos)
168 {
169         struct memtype *match;
170         int i = 1;
171
172         match = memtype_interval_iter_first(&memtype_rbroot, 0, ULONG_MAX);
173         while (match && pos != i) {
174                 match = memtype_interval_iter_next(match, 0, ULONG_MAX);
175                 i++;
176         }
177
178         if (match) { /* pos == i */
179                 *out = *match;
180                 return 0;
181         } else {
182                 return 1;
183         }
184 }
185 #endif