Merge tag 'for-net-next-2021-08-19' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-microblaze.git] / arch / arm64 / lib / csum.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
7
8 #include <net/checksum.h>
9
10 /* Looks dumb, but generates nice-ish code */
11 static u64 accumulate(u64 sum, u64 data)
12 {
13         __uint128_t tmp = (__uint128_t)sum + data;
14         return tmp + (tmp >> 64);
15 }
16
17 /*
18  * We over-read the buffer and this makes KASAN unhappy. Instead, disable
19  * instrumentation and call kasan explicitly.
20  */
21 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
22 {
23         unsigned int offset, shift, sum;
24         const u64 *ptr;
25         u64 data, sum64 = 0;
26
27         if (unlikely(len == 0))
28                 return 0;
29
30         offset = (unsigned long)buff & 7;
31         /*
32          * This is to all intents and purposes safe, since rounding down cannot
33          * result in a different page or cache line being accessed, and @buff
34          * should absolutely not be pointing to anything read-sensitive. We do,
35          * however, have to be careful not to piss off KASAN, which means using
36          * unchecked reads to accommodate the head and tail, for which we'll
37          * compensate with an explicit check up-front.
38          */
39         kasan_check_read(buff, len);
40         ptr = (u64 *)(buff - offset);
41         len = len + offset - 8;
42
43         /*
44          * Head: zero out any excess leading bytes. Shifting back by the same
45          * amount should be at least as fast as any other way of handling the
46          * odd/even alignment, and means we can ignore it until the very end.
47          */
48         shift = offset * 8;
49         data = *ptr++;
50 #ifdef __LITTLE_ENDIAN
51         data = (data >> shift) << shift;
52 #else
53         data = (data << shift) >> shift;
54 #endif
55
56         /*
57          * Body: straightforward aligned loads from here on (the paired loads
58          * underlying the quadword type still only need dword alignment). The
59          * main loop strictly excludes the tail, so the second loop will always
60          * run at least once.
61          */
62         while (unlikely(len > 64)) {
63                 __uint128_t tmp1, tmp2, tmp3, tmp4;
64
65                 tmp1 = *(__uint128_t *)ptr;
66                 tmp2 = *(__uint128_t *)(ptr + 2);
67                 tmp3 = *(__uint128_t *)(ptr + 4);
68                 tmp4 = *(__uint128_t *)(ptr + 6);
69
70                 len -= 64;
71                 ptr += 8;
72
73                 /* This is the "don't dump the carry flag into a GPR" idiom */
74                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
75                 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
76                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
77                 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
78                 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
79                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
80                 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
81                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
82                 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
83                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
84                 tmp1 = ((tmp1 >> 64) << 64) | sum64;
85                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
86                 sum64 = tmp1 >> 64;
87         }
88         while (len > 8) {
89                 __uint128_t tmp;
90
91                 sum64 = accumulate(sum64, data);
92                 tmp = *(__uint128_t *)ptr;
93
94                 len -= 16;
95                 ptr += 2;
96
97 #ifdef __LITTLE_ENDIAN
98                 data = tmp >> 64;
99                 sum64 = accumulate(sum64, tmp);
100 #else
101                 data = tmp;
102                 sum64 = accumulate(sum64, tmp >> 64);
103 #endif
104         }
105         if (len > 0) {
106                 sum64 = accumulate(sum64, data);
107                 data = *ptr;
108                 len -= 8;
109         }
110         /*
111          * Tail: zero any over-read bytes similarly to the head, again
112          * preserving odd/even alignment.
113          */
114         shift = len * -8;
115 #ifdef __LITTLE_ENDIAN
116         data = (data << shift) >> shift;
117 #else
118         data = (data >> shift) << shift;
119 #endif
120         sum64 = accumulate(sum64, data);
121
122         /* Finally, folding */
123         sum64 += (sum64 >> 32) | (sum64 << 32);
124         sum = sum64 >> 32;
125         sum += (sum >> 16) | (sum << 16);
126         if (offset & 1)
127                 return (u16)swab32(sum);
128
129         return sum >> 16;
130 }
131
132 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
133                         const struct in6_addr *daddr,
134                         __u32 len, __u8 proto, __wsum csum)
135 {
136         __uint128_t src, dst;
137         u64 sum = (__force u64)csum;
138
139         src = *(const __uint128_t *)saddr->s6_addr;
140         dst = *(const __uint128_t *)daddr->s6_addr;
141
142         sum += (__force u32)htonl(len);
143 #ifdef __LITTLE_ENDIAN
144         sum += (u32)proto << 24;
145 #else
146         sum += proto;
147 #endif
148         src += (src >> 64) | (src << 64);
149         dst += (dst >> 64) | (dst << 64);
150
151         sum = accumulate(sum, src >> 64);
152         sum = accumulate(sum, dst >> 64);
153
154         sum += ((sum >> 32) | (sum << 32));
155         return csum_fold((__force __wsum)(sum >> 32));
156 }
157 EXPORT_SYMBOL(csum_ipv6_magic);