Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[linux-2.6-microblaze.git] / arch / arm64 / lib / csum.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
7
8 #include <net/checksum.h>
9
10 /* Looks dumb, but generates nice-ish code */
11 static u64 accumulate(u64 sum, u64 data)
12 {
13         __uint128_t tmp = (__uint128_t)sum + data;
14         return tmp + (tmp >> 64);
15 }
16
17 unsigned int do_csum(const unsigned char *buff, int len)
18 {
19         unsigned int offset, shift, sum;
20         const u64 *ptr;
21         u64 data, sum64 = 0;
22
23         if (unlikely(len == 0))
24                 return 0;
25
26         offset = (unsigned long)buff & 7;
27         /*
28          * This is to all intents and purposes safe, since rounding down cannot
29          * result in a different page or cache line being accessed, and @buff
30          * should absolutely not be pointing to anything read-sensitive. We do,
31          * however, have to be careful not to piss off KASAN, which means using
32          * unchecked reads to accommodate the head and tail, for which we'll
33          * compensate with an explicit check up-front.
34          */
35         kasan_check_read(buff, len);
36         ptr = (u64 *)(buff - offset);
37         len = len + offset - 8;
38
39         /*
40          * Head: zero out any excess leading bytes. Shifting back by the same
41          * amount should be at least as fast as any other way of handling the
42          * odd/even alignment, and means we can ignore it until the very end.
43          */
44         shift = offset * 8;
45         data = READ_ONCE_NOCHECK(*ptr++);
46 #ifdef __LITTLE_ENDIAN
47         data = (data >> shift) << shift;
48 #else
49         data = (data << shift) >> shift;
50 #endif
51
52         /*
53          * Body: straightforward aligned loads from here on (the paired loads
54          * underlying the quadword type still only need dword alignment). The
55          * main loop strictly excludes the tail, so the second loop will always
56          * run at least once.
57          */
58         while (unlikely(len > 64)) {
59                 __uint128_t tmp1, tmp2, tmp3, tmp4;
60
61                 tmp1 = READ_ONCE_NOCHECK(*(__uint128_t *)ptr);
62                 tmp2 = READ_ONCE_NOCHECK(*(__uint128_t *)(ptr + 2));
63                 tmp3 = READ_ONCE_NOCHECK(*(__uint128_t *)(ptr + 4));
64                 tmp4 = READ_ONCE_NOCHECK(*(__uint128_t *)(ptr + 6));
65
66                 len -= 64;
67                 ptr += 8;
68
69                 /* This is the "don't dump the carry flag into a GPR" idiom */
70                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
71                 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
72                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
73                 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
74                 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
75                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
76                 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
77                 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
78                 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
79                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
80                 tmp1 = ((tmp1 >> 64) << 64) | sum64;
81                 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
82                 sum64 = tmp1 >> 64;
83         }
84         while (len > 8) {
85                 __uint128_t tmp;
86
87                 sum64 = accumulate(sum64, data);
88                 tmp = READ_ONCE_NOCHECK(*(__uint128_t *)ptr);
89
90                 len -= 16;
91                 ptr += 2;
92
93 #ifdef __LITTLE_ENDIAN
94                 data = tmp >> 64;
95                 sum64 = accumulate(sum64, tmp);
96 #else
97                 data = tmp;
98                 sum64 = accumulate(sum64, tmp >> 64);
99 #endif
100         }
101         if (len > 0) {
102                 sum64 = accumulate(sum64, data);
103                 data = READ_ONCE_NOCHECK(*ptr);
104                 len -= 8;
105         }
106         /*
107          * Tail: zero any over-read bytes similarly to the head, again
108          * preserving odd/even alignment.
109          */
110         shift = len * -8;
111 #ifdef __LITTLE_ENDIAN
112         data = (data << shift) >> shift;
113 #else
114         data = (data >> shift) << shift;
115 #endif
116         sum64 = accumulate(sum64, data);
117
118         /* Finally, folding */
119         sum64 += (sum64 >> 32) | (sum64 << 32);
120         sum = sum64 >> 32;
121         sum += (sum >> 16) | (sum << 16);
122         if (offset & 1)
123                 return (u16)swab32(sum);
124
125         return sum >> 16;
126 }
127
128 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
129                         const struct in6_addr *daddr,
130                         __u32 len, __u8 proto, __wsum csum)
131 {
132         __uint128_t src, dst;
133         u64 sum = (__force u64)csum;
134
135         src = *(const __uint128_t *)saddr->s6_addr;
136         dst = *(const __uint128_t *)daddr->s6_addr;
137
138         sum += (__force u32)htonl(len);
139 #ifdef __LITTLE_ENDIAN
140         sum += (u32)proto << 24;
141 #else
142         sum += proto;
143 #endif
144         src += (src >> 64) | (src << 64);
145         dst += (dst >> 64) | (dst << 64);
146
147         sum = accumulate(sum, src >> 64);
148         sum = accumulate(sum, dst >> 64);
149
150         sum += ((sum >> 32) | (sum << 32));
151         return csum_fold((__force __wsum)(sum >> 32));
152 }
153 EXPORT_SYMBOL(csum_ipv6_magic);