ARM: multi_v5_defconfig: make DaVinci part of the ARM v5 multiplatform build
[linux-2.6-microblaze.git] / net / rxrpc / rxkad.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Kerberos-based RxRPC security
3  *
4  * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <crypto/skcipher.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/skbuff.h>
14 #include <linux/udp.h>
15 #include <linux/scatterlist.h>
16 #include <linux/ctype.h>
17 #include <linux/slab.h>
18 #include <net/sock.h>
19 #include <net/af_rxrpc.h>
20 #include <keys/rxrpc-type.h>
21 #include "ar-internal.h"
22
23 #define RXKAD_VERSION                   2
24 #define MAXKRB5TICKETLEN                1024
25 #define RXKAD_TKT_TYPE_KERBEROS_V5      256
26 #define ANAME_SZ                        40      /* size of authentication name */
27 #define INST_SZ                         40      /* size of principal's instance */
28 #define REALM_SZ                        40      /* size of principal's auth domain */
29 #define SNAME_SZ                        40      /* size of service name */
30
31 struct rxkad_level1_hdr {
32         __be32  data_size;      /* true data size (excluding padding) */
33 };
34
35 struct rxkad_level2_hdr {
36         __be32  data_size;      /* true data size (excluding padding) */
37         __be32  checksum;       /* decrypted data checksum */
38 };
39
40 /*
41  * this holds a pinned cipher so that keventd doesn't get called by the cipher
42  * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE
43  * packets
44  */
45 static struct crypto_sync_skcipher *rxkad_ci;
46 static DEFINE_MUTEX(rxkad_ci_mutex);
47
48 /*
49  * initialise connection security
50  */
51 static int rxkad_init_connection_security(struct rxrpc_connection *conn)
52 {
53         struct crypto_sync_skcipher *ci;
54         struct rxrpc_key_token *token;
55         int ret;
56
57         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key));
58
59         token = conn->params.key->payload.data[0];
60         conn->security_ix = token->security_index;
61
62         ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
63         if (IS_ERR(ci)) {
64                 _debug("no cipher");
65                 ret = PTR_ERR(ci);
66                 goto error;
67         }
68
69         if (crypto_sync_skcipher_setkey(ci, token->kad->session_key,
70                                    sizeof(token->kad->session_key)) < 0)
71                 BUG();
72
73         switch (conn->params.security_level) {
74         case RXRPC_SECURITY_PLAIN:
75                 break;
76         case RXRPC_SECURITY_AUTH:
77                 conn->size_align = 8;
78                 conn->security_size = sizeof(struct rxkad_level1_hdr);
79                 break;
80         case RXRPC_SECURITY_ENCRYPT:
81                 conn->size_align = 8;
82                 conn->security_size = sizeof(struct rxkad_level2_hdr);
83                 break;
84         default:
85                 ret = -EKEYREJECTED;
86                 goto error;
87         }
88
89         conn->cipher = ci;
90         ret = 0;
91 error:
92         _leave(" = %d", ret);
93         return ret;
94 }
95
96 /*
97  * prime the encryption state with the invariant parts of a connection's
98  * description
99  */
100 static int rxkad_prime_packet_security(struct rxrpc_connection *conn)
101 {
102         struct rxrpc_key_token *token;
103         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
104         struct scatterlist sg;
105         struct rxrpc_crypt iv;
106         __be32 *tmpbuf;
107         size_t tmpsize = 4 * sizeof(__be32);
108
109         _enter("");
110
111         if (!conn->params.key)
112                 return 0;
113
114         tmpbuf = kmalloc(tmpsize, GFP_KERNEL);
115         if (!tmpbuf)
116                 return -ENOMEM;
117
118         token = conn->params.key->payload.data[0];
119         memcpy(&iv, token->kad->session_key, sizeof(iv));
120
121         tmpbuf[0] = htonl(conn->proto.epoch);
122         tmpbuf[1] = htonl(conn->proto.cid);
123         tmpbuf[2] = 0;
124         tmpbuf[3] = htonl(conn->security_ix);
125
126         sg_init_one(&sg, tmpbuf, tmpsize);
127         skcipher_request_set_sync_tfm(req, conn->cipher);
128         skcipher_request_set_callback(req, 0, NULL, NULL);
129         skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x);
130         crypto_skcipher_encrypt(req);
131         skcipher_request_zero(req);
132
133         memcpy(&conn->csum_iv, tmpbuf + 2, sizeof(conn->csum_iv));
134         kfree(tmpbuf);
135         _leave(" = 0");
136         return 0;
137 }
138
139 /*
140  * partially encrypt a packet (level 1 security)
141  */
142 static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
143                                     struct sk_buff *skb,
144                                     u32 data_size,
145                                     void *sechdr,
146                                     struct skcipher_request *req)
147 {
148         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
149         struct rxkad_level1_hdr hdr;
150         struct rxrpc_crypt iv;
151         struct scatterlist sg;
152         u16 check;
153
154         _enter("");
155
156         check = sp->hdr.seq ^ call->call_id;
157         data_size |= (u32)check << 16;
158
159         hdr.data_size = htonl(data_size);
160         memcpy(sechdr, &hdr, sizeof(hdr));
161
162         /* start the encryption afresh */
163         memset(&iv, 0, sizeof(iv));
164
165         sg_init_one(&sg, sechdr, 8);
166         skcipher_request_set_sync_tfm(req, call->conn->cipher);
167         skcipher_request_set_callback(req, 0, NULL, NULL);
168         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
169         crypto_skcipher_encrypt(req);
170         skcipher_request_zero(req);
171
172         _leave(" = 0");
173         return 0;
174 }
175
176 /*
177  * wholly encrypt a packet (level 2 security)
178  */
179 static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call,
180                                        struct sk_buff *skb,
181                                        u32 data_size,
182                                        void *sechdr,
183                                        struct skcipher_request *req)
184 {
185         const struct rxrpc_key_token *token;
186         struct rxkad_level2_hdr rxkhdr;
187         struct rxrpc_skb_priv *sp;
188         struct rxrpc_crypt iv;
189         struct scatterlist sg[16];
190         struct sk_buff *trailer;
191         unsigned int len;
192         u16 check;
193         int nsg;
194         int err;
195
196         sp = rxrpc_skb(skb);
197
198         _enter("");
199
200         check = sp->hdr.seq ^ call->call_id;
201
202         rxkhdr.data_size = htonl(data_size | (u32)check << 16);
203         rxkhdr.checksum = 0;
204         memcpy(sechdr, &rxkhdr, sizeof(rxkhdr));
205
206         /* encrypt from the session key */
207         token = call->conn->params.key->payload.data[0];
208         memcpy(&iv, token->kad->session_key, sizeof(iv));
209
210         sg_init_one(&sg[0], sechdr, sizeof(rxkhdr));
211         skcipher_request_set_sync_tfm(req, call->conn->cipher);
212         skcipher_request_set_callback(req, 0, NULL, NULL);
213         skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x);
214         crypto_skcipher_encrypt(req);
215
216         /* we want to encrypt the skbuff in-place */
217         nsg = skb_cow_data(skb, 0, &trailer);
218         err = -ENOMEM;
219         if (nsg < 0 || nsg > 16)
220                 goto out;
221
222         len = data_size + call->conn->size_align - 1;
223         len &= ~(call->conn->size_align - 1);
224
225         sg_init_table(sg, nsg);
226         err = skb_to_sgvec(skb, sg, 0, len);
227         if (unlikely(err < 0))
228                 goto out;
229         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
230         crypto_skcipher_encrypt(req);
231
232         _leave(" = 0");
233         err = 0;
234
235 out:
236         skcipher_request_zero(req);
237         return err;
238 }
239
240 /*
241  * checksum an RxRPC packet header
242  */
243 static int rxkad_secure_packet(struct rxrpc_call *call,
244                                struct sk_buff *skb,
245                                size_t data_size,
246                                void *sechdr)
247 {
248         struct rxrpc_skb_priv *sp;
249         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
250         struct rxrpc_crypt iv;
251         struct scatterlist sg;
252         u32 x, y;
253         int ret;
254
255         sp = rxrpc_skb(skb);
256
257         _enter("{%d{%x}},{#%u},%zu,",
258                call->debug_id, key_serial(call->conn->params.key),
259                sp->hdr.seq, data_size);
260
261         if (!call->conn->cipher)
262                 return 0;
263
264         ret = key_validate(call->conn->params.key);
265         if (ret < 0)
266                 return ret;
267
268         /* continue encrypting from where we left off */
269         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
270
271         /* calculate the security checksum */
272         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
273         x |= sp->hdr.seq & 0x3fffffff;
274         call->crypto_buf[0] = htonl(call->call_id);
275         call->crypto_buf[1] = htonl(x);
276
277         sg_init_one(&sg, call->crypto_buf, 8);
278         skcipher_request_set_sync_tfm(req, call->conn->cipher);
279         skcipher_request_set_callback(req, 0, NULL, NULL);
280         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
281         crypto_skcipher_encrypt(req);
282         skcipher_request_zero(req);
283
284         y = ntohl(call->crypto_buf[1]);
285         y = (y >> 16) & 0xffff;
286         if (y == 0)
287                 y = 1; /* zero checksums are not permitted */
288         sp->hdr.cksum = y;
289
290         switch (call->conn->params.security_level) {
291         case RXRPC_SECURITY_PLAIN:
292                 ret = 0;
293                 break;
294         case RXRPC_SECURITY_AUTH:
295                 ret = rxkad_secure_packet_auth(call, skb, data_size, sechdr,
296                                                req);
297                 break;
298         case RXRPC_SECURITY_ENCRYPT:
299                 ret = rxkad_secure_packet_encrypt(call, skb, data_size,
300                                                   sechdr, req);
301                 break;
302         default:
303                 ret = -EPERM;
304                 break;
305         }
306
307         _leave(" = %d [set %hx]", ret, y);
308         return ret;
309 }
310
311 /*
312  * decrypt partial encryption on a packet (level 1 security)
313  */
314 static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
315                                  unsigned int offset, unsigned int len,
316                                  rxrpc_seq_t seq,
317                                  struct skcipher_request *req)
318 {
319         struct rxkad_level1_hdr sechdr;
320         struct rxrpc_crypt iv;
321         struct scatterlist sg[16];
322         struct sk_buff *trailer;
323         bool aborted;
324         u32 data_size, buf;
325         u16 check;
326         int nsg, ret;
327
328         _enter("");
329
330         if (len < 8) {
331                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
332                                            RXKADSEALEDINCON);
333                 goto protocol_error;
334         }
335
336         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
337          * directly into the target buffer.
338          */
339         nsg = skb_cow_data(skb, 0, &trailer);
340         if (nsg < 0 || nsg > 16)
341                 goto nomem;
342
343         sg_init_table(sg, nsg);
344         ret = skb_to_sgvec(skb, sg, offset, 8);
345         if (unlikely(ret < 0))
346                 return ret;
347
348         /* start the decryption afresh */
349         memset(&iv, 0, sizeof(iv));
350
351         skcipher_request_set_sync_tfm(req, call->conn->cipher);
352         skcipher_request_set_callback(req, 0, NULL, NULL);
353         skcipher_request_set_crypt(req, sg, sg, 8, iv.x);
354         crypto_skcipher_decrypt(req);
355         skcipher_request_zero(req);
356
357         /* Extract the decrypted packet length */
358         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
359                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
360                                              RXKADDATALEN);
361                 goto protocol_error;
362         }
363         offset += sizeof(sechdr);
364         len -= sizeof(sechdr);
365
366         buf = ntohl(sechdr.data_size);
367         data_size = buf & 0xffff;
368
369         check = buf >> 16;
370         check ^= seq ^ call->call_id;
371         check &= 0xffff;
372         if (check != 0) {
373                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
374                                              RXKADSEALEDINCON);
375                 goto protocol_error;
376         }
377
378         if (data_size > len) {
379                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
380                                              RXKADDATALEN);
381                 goto protocol_error;
382         }
383
384         _leave(" = 0 [dlen=%x]", data_size);
385         return 0;
386
387 protocol_error:
388         if (aborted)
389                 rxrpc_send_abort_packet(call);
390         return -EPROTO;
391
392 nomem:
393         _leave(" = -ENOMEM");
394         return -ENOMEM;
395 }
396
397 /*
398  * wholly decrypt a packet (level 2 security)
399  */
400 static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
401                                  unsigned int offset, unsigned int len,
402                                  rxrpc_seq_t seq,
403                                  struct skcipher_request *req)
404 {
405         const struct rxrpc_key_token *token;
406         struct rxkad_level2_hdr sechdr;
407         struct rxrpc_crypt iv;
408         struct scatterlist _sg[4], *sg;
409         struct sk_buff *trailer;
410         bool aborted;
411         u32 data_size, buf;
412         u16 check;
413         int nsg, ret;
414
415         _enter(",{%d}", skb->len);
416
417         if (len < 8) {
418                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
419                                              RXKADSEALEDINCON);
420                 goto protocol_error;
421         }
422
423         /* Decrypt the skbuff in-place.  TODO: We really want to decrypt
424          * directly into the target buffer.
425          */
426         nsg = skb_cow_data(skb, 0, &trailer);
427         if (nsg < 0)
428                 goto nomem;
429
430         sg = _sg;
431         if (unlikely(nsg > 4)) {
432                 sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO);
433                 if (!sg)
434                         goto nomem;
435         }
436
437         sg_init_table(sg, nsg);
438         ret = skb_to_sgvec(skb, sg, offset, len);
439         if (unlikely(ret < 0)) {
440                 if (sg != _sg)
441                         kfree(sg);
442                 return ret;
443         }
444
445         /* decrypt from the session key */
446         token = call->conn->params.key->payload.data[0];
447         memcpy(&iv, token->kad->session_key, sizeof(iv));
448
449         skcipher_request_set_sync_tfm(req, call->conn->cipher);
450         skcipher_request_set_callback(req, 0, NULL, NULL);
451         skcipher_request_set_crypt(req, sg, sg, len, iv.x);
452         crypto_skcipher_decrypt(req);
453         skcipher_request_zero(req);
454         if (sg != _sg)
455                 kfree(sg);
456
457         /* Extract the decrypted packet length */
458         if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
459                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
460                                              RXKADDATALEN);
461                 goto protocol_error;
462         }
463         offset += sizeof(sechdr);
464         len -= sizeof(sechdr);
465
466         buf = ntohl(sechdr.data_size);
467         data_size = buf & 0xffff;
468
469         check = buf >> 16;
470         check ^= seq ^ call->call_id;
471         check &= 0xffff;
472         if (check != 0) {
473                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
474                                              RXKADSEALEDINCON);
475                 goto protocol_error;
476         }
477
478         if (data_size > len) {
479                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
480                                              RXKADDATALEN);
481                 goto protocol_error;
482         }
483
484         _leave(" = 0 [dlen=%x]", data_size);
485         return 0;
486
487 protocol_error:
488         if (aborted)
489                 rxrpc_send_abort_packet(call);
490         return -EPROTO;
491
492 nomem:
493         _leave(" = -ENOMEM");
494         return -ENOMEM;
495 }
496
497 /*
498  * Verify the security on a received packet or subpacket (if part of a
499  * jumbo packet).
500  */
501 static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
502                                unsigned int offset, unsigned int len,
503                                rxrpc_seq_t seq, u16 expected_cksum)
504 {
505         SYNC_SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
506         struct rxrpc_crypt iv;
507         struct scatterlist sg;
508         bool aborted;
509         u16 cksum;
510         u32 x, y;
511
512         _enter("{%d{%x}},{#%u}",
513                call->debug_id, key_serial(call->conn->params.key), seq);
514
515         if (!call->conn->cipher)
516                 return 0;
517
518         /* continue encrypting from where we left off */
519         memcpy(&iv, call->conn->csum_iv.x, sizeof(iv));
520
521         /* validate the security checksum */
522         x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT);
523         x |= seq & 0x3fffffff;
524         call->crypto_buf[0] = htonl(call->call_id);
525         call->crypto_buf[1] = htonl(x);
526
527         sg_init_one(&sg, call->crypto_buf, 8);
528         skcipher_request_set_sync_tfm(req, call->conn->cipher);
529         skcipher_request_set_callback(req, 0, NULL, NULL);
530         skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x);
531         crypto_skcipher_encrypt(req);
532         skcipher_request_zero(req);
533
534         y = ntohl(call->crypto_buf[1]);
535         cksum = (y >> 16) & 0xffff;
536         if (cksum == 0)
537                 cksum = 1; /* zero checksums are not permitted */
538
539         if (cksum != expected_cksum) {
540                 aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
541                                              RXKADSEALEDINCON);
542                 goto protocol_error;
543         }
544
545         switch (call->conn->params.security_level) {
546         case RXRPC_SECURITY_PLAIN:
547                 return 0;
548         case RXRPC_SECURITY_AUTH:
549                 return rxkad_verify_packet_1(call, skb, offset, len, seq, req);
550         case RXRPC_SECURITY_ENCRYPT:
551                 return rxkad_verify_packet_2(call, skb, offset, len, seq, req);
552         default:
553                 return -ENOANO;
554         }
555
556 protocol_error:
557         if (aborted)
558                 rxrpc_send_abort_packet(call);
559         return -EPROTO;
560 }
561
562 /*
563  * Locate the data contained in a packet that was partially encrypted.
564  */
565 static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb,
566                                 unsigned int *_offset, unsigned int *_len)
567 {
568         struct rxkad_level1_hdr sechdr;
569
570         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
571                 BUG();
572         *_offset += sizeof(sechdr);
573         *_len = ntohl(sechdr.data_size) & 0xffff;
574 }
575
576 /*
577  * Locate the data contained in a packet that was completely encrypted.
578  */
579 static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb,
580                                 unsigned int *_offset, unsigned int *_len)
581 {
582         struct rxkad_level2_hdr sechdr;
583
584         if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0)
585                 BUG();
586         *_offset += sizeof(sechdr);
587         *_len = ntohl(sechdr.data_size) & 0xffff;
588 }
589
590 /*
591  * Locate the data contained in an already decrypted packet.
592  */
593 static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb,
594                               unsigned int *_offset, unsigned int *_len)
595 {
596         switch (call->conn->params.security_level) {
597         case RXRPC_SECURITY_AUTH:
598                 rxkad_locate_data_1(call, skb, _offset, _len);
599                 return;
600         case RXRPC_SECURITY_ENCRYPT:
601                 rxkad_locate_data_2(call, skb, _offset, _len);
602                 return;
603         default:
604                 return;
605         }
606 }
607
608 /*
609  * issue a challenge
610  */
611 static int rxkad_issue_challenge(struct rxrpc_connection *conn)
612 {
613         struct rxkad_challenge challenge;
614         struct rxrpc_wire_header whdr;
615         struct msghdr msg;
616         struct kvec iov[2];
617         size_t len;
618         u32 serial;
619         int ret;
620
621         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
622
623         ret = key_validate(conn->params.key);
624         if (ret < 0)
625                 return ret;
626
627         get_random_bytes(&conn->security_nonce, sizeof(conn->security_nonce));
628
629         challenge.version       = htonl(2);
630         challenge.nonce         = htonl(conn->security_nonce);
631         challenge.min_level     = htonl(0);
632         challenge.__padding     = 0;
633
634         msg.msg_name    = &conn->params.peer->srx.transport;
635         msg.msg_namelen = conn->params.peer->srx.transport_len;
636         msg.msg_control = NULL;
637         msg.msg_controllen = 0;
638         msg.msg_flags   = 0;
639
640         whdr.epoch      = htonl(conn->proto.epoch);
641         whdr.cid        = htonl(conn->proto.cid);
642         whdr.callNumber = 0;
643         whdr.seq        = 0;
644         whdr.type       = RXRPC_PACKET_TYPE_CHALLENGE;
645         whdr.flags      = conn->out_clientflag;
646         whdr.userStatus = 0;
647         whdr.securityIndex = conn->security_ix;
648         whdr._rsvd      = 0;
649         whdr.serviceId  = htons(conn->service_id);
650
651         iov[0].iov_base = &whdr;
652         iov[0].iov_len  = sizeof(whdr);
653         iov[1].iov_base = &challenge;
654         iov[1].iov_len  = sizeof(challenge);
655
656         len = iov[0].iov_len + iov[1].iov_len;
657
658         serial = atomic_inc_return(&conn->serial);
659         whdr.serial = htonl(serial);
660         _proto("Tx CHALLENGE %%%u", serial);
661
662         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len);
663         if (ret < 0) {
664                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
665                                     rxrpc_tx_point_rxkad_challenge);
666                 return -EAGAIN;
667         }
668
669         conn->params.peer->last_tx_at = ktime_get_seconds();
670         trace_rxrpc_tx_packet(conn->debug_id, &whdr,
671                               rxrpc_tx_point_rxkad_challenge);
672         _leave(" = 0");
673         return 0;
674 }
675
676 /*
677  * send a Kerberos security response
678  */
679 static int rxkad_send_response(struct rxrpc_connection *conn,
680                                struct rxrpc_host_header *hdr,
681                                struct rxkad_response *resp,
682                                const struct rxkad_key *s2)
683 {
684         struct rxrpc_wire_header whdr;
685         struct msghdr msg;
686         struct kvec iov[3];
687         size_t len;
688         u32 serial;
689         int ret;
690
691         _enter("");
692
693         msg.msg_name    = &conn->params.peer->srx.transport;
694         msg.msg_namelen = conn->params.peer->srx.transport_len;
695         msg.msg_control = NULL;
696         msg.msg_controllen = 0;
697         msg.msg_flags   = 0;
698
699         memset(&whdr, 0, sizeof(whdr));
700         whdr.epoch      = htonl(hdr->epoch);
701         whdr.cid        = htonl(hdr->cid);
702         whdr.type       = RXRPC_PACKET_TYPE_RESPONSE;
703         whdr.flags      = conn->out_clientflag;
704         whdr.securityIndex = hdr->securityIndex;
705         whdr.serviceId  = htons(hdr->serviceId);
706
707         iov[0].iov_base = &whdr;
708         iov[0].iov_len  = sizeof(whdr);
709         iov[1].iov_base = resp;
710         iov[1].iov_len  = sizeof(*resp);
711         iov[2].iov_base = (void *)s2->ticket;
712         iov[2].iov_len  = s2->ticket_len;
713
714         len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len;
715
716         serial = atomic_inc_return(&conn->serial);
717         whdr.serial = htonl(serial);
718         _proto("Tx RESPONSE %%%u", serial);
719
720         ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len);
721         if (ret < 0) {
722                 trace_rxrpc_tx_fail(conn->debug_id, serial, ret,
723                                     rxrpc_tx_point_rxkad_response);
724                 return -EAGAIN;
725         }
726
727         conn->params.peer->last_tx_at = ktime_get_seconds();
728         _leave(" = 0");
729         return 0;
730 }
731
732 /*
733  * calculate the response checksum
734  */
735 static void rxkad_calc_response_checksum(struct rxkad_response *response)
736 {
737         u32 csum = 1000003;
738         int loop;
739         u8 *p = (u8 *) response;
740
741         for (loop = sizeof(*response); loop > 0; loop--)
742                 csum = csum * 0x10204081 + *p++;
743
744         response->encrypted.checksum = htonl(csum);
745 }
746
747 /*
748  * encrypt the response packet
749  */
750 static void rxkad_encrypt_response(struct rxrpc_connection *conn,
751                                    struct rxkad_response *resp,
752                                    const struct rxkad_key *s2)
753 {
754         SYNC_SKCIPHER_REQUEST_ON_STACK(req, conn->cipher);
755         struct rxrpc_crypt iv;
756         struct scatterlist sg[1];
757
758         /* continue encrypting from where we left off */
759         memcpy(&iv, s2->session_key, sizeof(iv));
760
761         sg_init_table(sg, 1);
762         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
763         skcipher_request_set_sync_tfm(req, conn->cipher);
764         skcipher_request_set_callback(req, 0, NULL, NULL);
765         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
766         crypto_skcipher_encrypt(req);
767         skcipher_request_zero(req);
768 }
769
770 /*
771  * respond to a challenge packet
772  */
773 static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
774                                       struct sk_buff *skb,
775                                       u32 *_abort_code)
776 {
777         const struct rxrpc_key_token *token;
778         struct rxkad_challenge challenge;
779         struct rxkad_response *resp;
780         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
781         const char *eproto;
782         u32 version, nonce, min_level, abort_code;
783         int ret;
784
785         _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
786
787         eproto = tracepoint_string("chall_no_key");
788         abort_code = RX_PROTOCOL_ERROR;
789         if (!conn->params.key)
790                 goto protocol_error;
791
792         abort_code = RXKADEXPIRED;
793         ret = key_validate(conn->params.key);
794         if (ret < 0)
795                 goto other_error;
796
797         eproto = tracepoint_string("chall_short");
798         abort_code = RXKADPACKETSHORT;
799         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
800                           &challenge, sizeof(challenge)) < 0)
801                 goto protocol_error;
802
803         version = ntohl(challenge.version);
804         nonce = ntohl(challenge.nonce);
805         min_level = ntohl(challenge.min_level);
806
807         _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
808                sp->hdr.serial, version, nonce, min_level);
809
810         eproto = tracepoint_string("chall_ver");
811         abort_code = RXKADINCONSISTENCY;
812         if (version != RXKAD_VERSION)
813                 goto protocol_error;
814
815         abort_code = RXKADLEVELFAIL;
816         ret = -EACCES;
817         if (conn->params.security_level < min_level)
818                 goto other_error;
819
820         token = conn->params.key->payload.data[0];
821
822         /* build the response packet */
823         resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
824         if (!resp)
825                 return -ENOMEM;
826
827         resp->version                   = htonl(RXKAD_VERSION);
828         resp->encrypted.epoch           = htonl(conn->proto.epoch);
829         resp->encrypted.cid             = htonl(conn->proto.cid);
830         resp->encrypted.securityIndex   = htonl(conn->security_ix);
831         resp->encrypted.inc_nonce       = htonl(nonce + 1);
832         resp->encrypted.level           = htonl(conn->params.security_level);
833         resp->kvno                      = htonl(token->kad->kvno);
834         resp->ticket_len                = htonl(token->kad->ticket_len);
835         resp->encrypted.call_id[0]      = htonl(conn->channels[0].call_counter);
836         resp->encrypted.call_id[1]      = htonl(conn->channels[1].call_counter);
837         resp->encrypted.call_id[2]      = htonl(conn->channels[2].call_counter);
838         resp->encrypted.call_id[3]      = htonl(conn->channels[3].call_counter);
839
840         /* calculate the response checksum and then do the encryption */
841         rxkad_calc_response_checksum(resp);
842         rxkad_encrypt_response(conn, resp, token->kad);
843         ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad);
844         kfree(resp);
845         return ret;
846
847 protocol_error:
848         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
849         ret = -EPROTO;
850 other_error:
851         *_abort_code = abort_code;
852         return ret;
853 }
854
855 /*
856  * decrypt the kerberos IV ticket in the response
857  */
858 static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
859                                 struct sk_buff *skb,
860                                 void *ticket, size_t ticket_len,
861                                 struct rxrpc_crypt *_session_key,
862                                 time64_t *_expiry,
863                                 u32 *_abort_code)
864 {
865         struct skcipher_request *req;
866         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
867         struct rxrpc_crypt iv, key;
868         struct scatterlist sg[1];
869         struct in_addr addr;
870         unsigned int life;
871         const char *eproto;
872         time64_t issue, now;
873         bool little_endian;
874         int ret;
875         u32 abort_code;
876         u8 *p, *q, *name, *end;
877
878         _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
879
880         *_expiry = 0;
881
882         ret = key_validate(conn->server_key);
883         if (ret < 0) {
884                 switch (ret) {
885                 case -EKEYEXPIRED:
886                         abort_code = RXKADEXPIRED;
887                         goto other_error;
888                 default:
889                         abort_code = RXKADNOAUTH;
890                         goto other_error;
891                 }
892         }
893
894         ASSERT(conn->server_key->payload.data[0] != NULL);
895         ASSERTCMP((unsigned long) ticket & 7UL, ==, 0);
896
897         memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv));
898
899         ret = -ENOMEM;
900         req = skcipher_request_alloc(conn->server_key->payload.data[0],
901                                      GFP_NOFS);
902         if (!req)
903                 goto temporary_error;
904
905         sg_init_one(&sg[0], ticket, ticket_len);
906         skcipher_request_set_callback(req, 0, NULL, NULL);
907         skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x);
908         crypto_skcipher_decrypt(req);
909         skcipher_request_free(req);
910
911         p = ticket;
912         end = p + ticket_len;
913
914 #define Z(field)                                        \
915         ({                                              \
916                 u8 *__str = p;                          \
917                 eproto = tracepoint_string("rxkad_bad_"#field); \
918                 q = memchr(p, 0, end - p);              \
919                 if (!q || q - p > (field##_SZ))         \
920                         goto bad_ticket;                \
921                 for (; p < q; p++)                      \
922                         if (!isprint(*p))               \
923                                 goto bad_ticket;        \
924                 p++;                                    \
925                 __str;                                  \
926         })
927
928         /* extract the ticket flags */
929         _debug("KIV FLAGS: %x", *p);
930         little_endian = *p & 1;
931         p++;
932
933         /* extract the authentication name */
934         name = Z(ANAME);
935         _debug("KIV ANAME: %s", name);
936
937         /* extract the principal's instance */
938         name = Z(INST);
939         _debug("KIV INST : %s", name);
940
941         /* extract the principal's authentication domain */
942         name = Z(REALM);
943         _debug("KIV REALM: %s", name);
944
945         eproto = tracepoint_string("rxkad_bad_len");
946         if (end - p < 4 + 8 + 4 + 2)
947                 goto bad_ticket;
948
949         /* get the IPv4 address of the entity that requested the ticket */
950         memcpy(&addr, p, sizeof(addr));
951         p += 4;
952         _debug("KIV ADDR : %pI4", &addr);
953
954         /* get the session key from the ticket */
955         memcpy(&key, p, sizeof(key));
956         p += 8;
957         _debug("KIV KEY  : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1]));
958         memcpy(_session_key, &key, sizeof(key));
959
960         /* get the ticket's lifetime */
961         life = *p++ * 5 * 60;
962         _debug("KIV LIFE : %u", life);
963
964         /* get the issue time of the ticket */
965         if (little_endian) {
966                 __le32 stamp;
967                 memcpy(&stamp, p, 4);
968                 issue = rxrpc_u32_to_time64(le32_to_cpu(stamp));
969         } else {
970                 __be32 stamp;
971                 memcpy(&stamp, p, 4);
972                 issue = rxrpc_u32_to_time64(be32_to_cpu(stamp));
973         }
974         p += 4;
975         now = ktime_get_real_seconds();
976         _debug("KIV ISSUE: %llx [%llx]", issue, now);
977
978         /* check the ticket is in date */
979         if (issue > now) {
980                 abort_code = RXKADNOAUTH;
981                 ret = -EKEYREJECTED;
982                 goto other_error;
983         }
984
985         if (issue < now - life) {
986                 abort_code = RXKADEXPIRED;
987                 ret = -EKEYEXPIRED;
988                 goto other_error;
989         }
990
991         *_expiry = issue + life;
992
993         /* get the service name */
994         name = Z(SNAME);
995         _debug("KIV SNAME: %s", name);
996
997         /* get the service instance name */
998         name = Z(INST);
999         _debug("KIV SINST: %s", name);
1000         return 0;
1001
1002 bad_ticket:
1003         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1004         abort_code = RXKADBADTICKET;
1005         ret = -EPROTO;
1006 other_error:
1007         *_abort_code = abort_code;
1008         return ret;
1009 temporary_error:
1010         return ret;
1011 }
1012
1013 /*
1014  * decrypt the response packet
1015  */
1016 static void rxkad_decrypt_response(struct rxrpc_connection *conn,
1017                                    struct rxkad_response *resp,
1018                                    const struct rxrpc_crypt *session_key)
1019 {
1020         SYNC_SKCIPHER_REQUEST_ON_STACK(req, rxkad_ci);
1021         struct scatterlist sg[1];
1022         struct rxrpc_crypt iv;
1023
1024         _enter(",,%08x%08x",
1025                ntohl(session_key->n[0]), ntohl(session_key->n[1]));
1026
1027         ASSERT(rxkad_ci != NULL);
1028
1029         mutex_lock(&rxkad_ci_mutex);
1030         if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x,
1031                                    sizeof(*session_key)) < 0)
1032                 BUG();
1033
1034         memcpy(&iv, session_key, sizeof(iv));
1035
1036         sg_init_table(sg, 1);
1037         sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted));
1038         skcipher_request_set_sync_tfm(req, rxkad_ci);
1039         skcipher_request_set_callback(req, 0, NULL, NULL);
1040         skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x);
1041         crypto_skcipher_decrypt(req);
1042         skcipher_request_zero(req);
1043
1044         mutex_unlock(&rxkad_ci_mutex);
1045
1046         _leave("");
1047 }
1048
1049 /*
1050  * verify a response
1051  */
1052 static int rxkad_verify_response(struct rxrpc_connection *conn,
1053                                  struct sk_buff *skb,
1054                                  u32 *_abort_code)
1055 {
1056         struct rxkad_response *response;
1057         struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
1058         struct rxrpc_crypt session_key;
1059         const char *eproto;
1060         time64_t expiry;
1061         void *ticket;
1062         u32 abort_code, version, kvno, ticket_len, level;
1063         __be32 csum;
1064         int ret, i;
1065
1066         _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
1067
1068         ret = -ENOMEM;
1069         response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS);
1070         if (!response)
1071                 goto temporary_error;
1072
1073         eproto = tracepoint_string("rxkad_rsp_short");
1074         abort_code = RXKADPACKETSHORT;
1075         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1076                           response, sizeof(*response)) < 0)
1077                 goto protocol_error;
1078         if (!pskb_pull(skb, sizeof(*response)))
1079                 BUG();
1080
1081         version = ntohl(response->version);
1082         ticket_len = ntohl(response->ticket_len);
1083         kvno = ntohl(response->kvno);
1084         _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
1085                sp->hdr.serial, version, kvno, ticket_len);
1086
1087         eproto = tracepoint_string("rxkad_rsp_ver");
1088         abort_code = RXKADINCONSISTENCY;
1089         if (version != RXKAD_VERSION)
1090                 goto protocol_error;
1091
1092         eproto = tracepoint_string("rxkad_rsp_tktlen");
1093         abort_code = RXKADTICKETLEN;
1094         if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
1095                 goto protocol_error;
1096
1097         eproto = tracepoint_string("rxkad_rsp_unkkey");
1098         abort_code = RXKADUNKNOWNKEY;
1099         if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
1100                 goto protocol_error;
1101
1102         /* extract the kerberos ticket and decrypt and decode it */
1103         ret = -ENOMEM;
1104         ticket = kmalloc(ticket_len, GFP_NOFS);
1105         if (!ticket)
1106                 goto temporary_error;
1107
1108         eproto = tracepoint_string("rxkad_tkt_short");
1109         abort_code = RXKADPACKETSHORT;
1110         if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
1111                           ticket, ticket_len) < 0)
1112                 goto protocol_error_free;
1113
1114         ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key,
1115                                    &expiry, _abort_code);
1116         if (ret < 0)
1117                 goto temporary_error_free_resp;
1118
1119         /* use the session key from inside the ticket to decrypt the
1120          * response */
1121         rxkad_decrypt_response(conn, response, &session_key);
1122
1123         eproto = tracepoint_string("rxkad_rsp_param");
1124         abort_code = RXKADSEALEDINCON;
1125         if (ntohl(response->encrypted.epoch) != conn->proto.epoch)
1126                 goto protocol_error_free;
1127         if (ntohl(response->encrypted.cid) != conn->proto.cid)
1128                 goto protocol_error_free;
1129         if (ntohl(response->encrypted.securityIndex) != conn->security_ix)
1130                 goto protocol_error_free;
1131         csum = response->encrypted.checksum;
1132         response->encrypted.checksum = 0;
1133         rxkad_calc_response_checksum(response);
1134         eproto = tracepoint_string("rxkad_rsp_csum");
1135         if (response->encrypted.checksum != csum)
1136                 goto protocol_error_free;
1137
1138         spin_lock(&conn->channel_lock);
1139         for (i = 0; i < RXRPC_MAXCALLS; i++) {
1140                 struct rxrpc_call *call;
1141                 u32 call_id = ntohl(response->encrypted.call_id[i]);
1142
1143                 eproto = tracepoint_string("rxkad_rsp_callid");
1144                 if (call_id > INT_MAX)
1145                         goto protocol_error_unlock;
1146
1147                 eproto = tracepoint_string("rxkad_rsp_callctr");
1148                 if (call_id < conn->channels[i].call_counter)
1149                         goto protocol_error_unlock;
1150
1151                 eproto = tracepoint_string("rxkad_rsp_callst");
1152                 if (call_id > conn->channels[i].call_counter) {
1153                         call = rcu_dereference_protected(
1154                                 conn->channels[i].call,
1155                                 lockdep_is_held(&conn->channel_lock));
1156                         if (call && call->state < RXRPC_CALL_COMPLETE)
1157                                 goto protocol_error_unlock;
1158                         conn->channels[i].call_counter = call_id;
1159                 }
1160         }
1161         spin_unlock(&conn->channel_lock);
1162
1163         eproto = tracepoint_string("rxkad_rsp_seq");
1164         abort_code = RXKADOUTOFSEQUENCE;
1165         if (ntohl(response->encrypted.inc_nonce) != conn->security_nonce + 1)
1166                 goto protocol_error_free;
1167
1168         eproto = tracepoint_string("rxkad_rsp_level");
1169         abort_code = RXKADLEVELFAIL;
1170         level = ntohl(response->encrypted.level);
1171         if (level > RXRPC_SECURITY_ENCRYPT)
1172                 goto protocol_error_free;
1173         conn->params.security_level = level;
1174
1175         /* create a key to hold the security data and expiration time - after
1176          * this the connection security can be handled in exactly the same way
1177          * as for a client connection */
1178         ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
1179         if (ret < 0)
1180                 goto temporary_error_free_ticket;
1181
1182         kfree(ticket);
1183         kfree(response);
1184         _leave(" = 0");
1185         return 0;
1186
1187 protocol_error_unlock:
1188         spin_unlock(&conn->channel_lock);
1189 protocol_error_free:
1190         kfree(ticket);
1191 protocol_error:
1192         kfree(response);
1193         trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
1194         *_abort_code = abort_code;
1195         return -EPROTO;
1196
1197 temporary_error_free_ticket:
1198         kfree(ticket);
1199 temporary_error_free_resp:
1200         kfree(response);
1201 temporary_error:
1202         /* Ignore the response packet if we got a temporary error such as
1203          * ENOMEM.  We just want to send the challenge again.  Note that we
1204          * also come out this way if the ticket decryption fails.
1205          */
1206         return ret;
1207 }
1208
1209 /*
1210  * clear the connection security
1211  */
1212 static void rxkad_clear(struct rxrpc_connection *conn)
1213 {
1214         _enter("");
1215
1216         if (conn->cipher)
1217                 crypto_free_sync_skcipher(conn->cipher);
1218 }
1219
1220 /*
1221  * Initialise the rxkad security service.
1222  */
1223 static int rxkad_init(void)
1224 {
1225         /* pin the cipher we need so that the crypto layer doesn't invoke
1226          * keventd to go get it */
1227         rxkad_ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0);
1228         return PTR_ERR_OR_ZERO(rxkad_ci);
1229 }
1230
1231 /*
1232  * Clean up the rxkad security service.
1233  */
1234 static void rxkad_exit(void)
1235 {
1236         if (rxkad_ci)
1237                 crypto_free_sync_skcipher(rxkad_ci);
1238 }
1239
1240 /*
1241  * RxRPC Kerberos-based security
1242  */
1243 const struct rxrpc_security rxkad = {
1244         .name                           = "rxkad",
1245         .security_index                 = RXRPC_SECURITY_RXKAD,
1246         .init                           = rxkad_init,
1247         .exit                           = rxkad_exit,
1248         .init_connection_security       = rxkad_init_connection_security,
1249         .prime_packet_security          = rxkad_prime_packet_security,
1250         .secure_packet                  = rxkad_secure_packet,
1251         .verify_packet                  = rxkad_verify_packet,
1252         .locate_data                    = rxkad_locate_data,
1253         .issue_challenge                = rxkad_issue_challenge,
1254         .respond_to_challenge           = rxkad_respond_to_challenge,
1255         .verify_response                = rxkad_verify_response,
1256         .clear                          = rxkad_clear,
1257 };