rxrpc, afs: Allow afs to pin rxrpc_peer objects
authorDavid Howells <dhowells@redhat.com>
Thu, 19 Oct 2023 11:55:11 +0000 (12:55 +0100)
committerDavid Howells <dhowells@redhat.com>
Sun, 24 Dec 2023 15:22:50 +0000 (15:22 +0000)
Change rxrpc's API such that:

 (1) A new function, rxrpc_kernel_lookup_peer(), is provided to look up an
     rxrpc_peer record for a remote address and a corresponding function,
     rxrpc_kernel_put_peer(), is provided to dispose of it again.

 (2) When setting up a call, the rxrpc_peer object used during a call is
     now passed in rather than being set up by rxrpc_connect_call().  For
     afs, this meenat passing it to rxrpc_kernel_begin_call() rather than
     the full address (the service ID then has to be passed in as a
     separate parameter).

 (3) A new function, rxrpc_kernel_remote_addr(), is added so that afs can
     get a pointer to the transport address for display purposed, and
     another, rxrpc_kernel_remote_srx(), to gain a pointer to the full
     rxrpc address.

 (4) The function to retrieve the RTT from a call, rxrpc_kernel_get_srtt(),
     is then altered to take a peer.  This now returns the RTT or -1 if
     there are insufficient samples.

 (5) Rename rxrpc_kernel_get_peer() to rxrpc_kernel_call_get_peer().

 (6) Provide a new function, rxrpc_kernel_get_peer(), to get a ref on a
     peer the caller already has.

This allows the afs filesystem to pin the rxrpc_peer records that it is
using, allowing faster lookups and pointer comparisons rather than
comparing sockaddr_rxrpc contents.  It also makes it easier to get hold of
the RTT.  The following changes are made to afs:

 (1) The addr_list struct's addrs[] elements now hold a peer struct pointer
     and a service ID rather than a sockaddr_rxrpc.

 (2) When displaying the transport address, rxrpc_kernel_remote_addr() is
     used.

 (3) The port arg is removed from afs_alloc_addrlist() since it's always
     overridden.

 (4) afs_merge_fs_addr4() and afs_merge_fs_addr6() do peer lookup and may
     now return an error that must be handled.

 (5) afs_find_server() now takes a peer pointer to specify the address.

 (6) afs_find_server(), afs_compare_fs_alists() and afs_merge_fs_addr[46]{}
     now do peer pointer comparison rather than address comparison.

Signed-off-by: David Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org

20 files changed:
fs/afs/addr_list.c
fs/afs/cmservice.c
fs/afs/fs_probe.c
fs/afs/internal.h
fs/afs/proc.c
fs/afs/rotate.c
fs/afs/rxrpc.c
fs/afs/server.c
fs/afs/vl_alias.c
fs/afs/vl_list.c
fs/afs/vl_probe.c
fs/afs/vl_rotate.c
fs/afs/vlclient.c
include/net/af_rxrpc.h
include/trace/events/rxrpc.h
net/rxrpc/af_rxrpc.c
net/rxrpc/ar-internal.h
net/rxrpc/call_object.c
net/rxrpc/peer_object.c
net/rxrpc/sendmsg.c

index ac05a59..519821f 100644 (file)
 #include "internal.h"
 #include "afs_fs.h"
 
+static void afs_free_addrlist(struct rcu_head *rcu)
+{
+       struct afs_addr_list *alist = container_of(rcu, struct afs_addr_list, rcu);
+       unsigned int i;
+
+       for (i = 0; i < alist->nr_addrs; i++)
+               rxrpc_kernel_put_peer(alist->addrs[i].peer);
+}
+
 /*
  * Release an address list.
  */
 void afs_put_addrlist(struct afs_addr_list *alist)
 {
        if (alist && refcount_dec_and_test(&alist->usage))
-               kfree_rcu(alist, rcu);
+               call_rcu(&alist->rcu, afs_free_addrlist);
 }
 
 /*
  * Allocate an address list.
  */
-struct afs_addr_list *afs_alloc_addrlist(unsigned int nr,
-                                        unsigned short service,
-                                        unsigned short port)
+struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, u16 service_id)
 {
        struct afs_addr_list *alist;
        unsigned int i;
 
-       _enter("%u,%u,%u", nr, service, port);
+       _enter("%u,%u", nr, service_id);
 
        if (nr > AFS_MAX_ADDRESSES)
                nr = AFS_MAX_ADDRESSES;
@@ -44,16 +51,8 @@ struct afs_addr_list *afs_alloc_addrlist(unsigned int nr,
        refcount_set(&alist->usage, 1);
        alist->max_addrs = nr;
 
-       for (i = 0; i < nr; i++) {
-               struct sockaddr_rxrpc *srx = &alist->addrs[i].srx;
-               srx->srx_family                 = AF_RXRPC;
-               srx->srx_service                = service;
-               srx->transport_type             = SOCK_DGRAM;
-               srx->transport_len              = sizeof(srx->transport.sin6);
-               srx->transport.sin6.sin6_family = AF_INET6;
-               srx->transport.sin6.sin6_port   = htons(port);
-       }
-
+       for (i = 0; i < nr; i++)
+               alist->addrs[i].service_id = service_id;
        return alist;
 }
 
@@ -126,7 +125,7 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
        if (!vllist->servers[0].server)
                goto error_vl;
 
-       alist = afs_alloc_addrlist(nr, service, AFS_VL_PORT);
+       alist = afs_alloc_addrlist(nr, service);
        if (!alist)
                goto error;
 
@@ -197,9 +196,11 @@ struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *net,
                }
 
                if (family == AF_INET)
-                       afs_merge_fs_addr4(alist, x[0], xport);
+                       ret = afs_merge_fs_addr4(net, alist, x[0], xport);
                else
-                       afs_merge_fs_addr6(alist, x, xport);
+                       ret = afs_merge_fs_addr6(net, alist, x, xport);
+               if (ret < 0)
+                       goto error;
 
        } while (p < end);
 
@@ -271,25 +272,33 @@ struct afs_vlserver_list *afs_dns_query(struct afs_cell *cell, time64_t *_expiry
 /*
  * Merge an IPv4 entry into a fileserver address list.
  */
-void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port)
+int afs_merge_fs_addr4(struct afs_net *net, struct afs_addr_list *alist,
+                      __be32 xdr, u16 port)
 {
-       struct sockaddr_rxrpc *srx;
-       u32 addr = ntohl(xdr);
+       struct sockaddr_rxrpc srx;
+       struct rxrpc_peer *peer;
        int i;
 
        if (alist->nr_addrs >= alist->max_addrs)
-               return;
+               return 0;
 
-       for (i = 0; i < alist->nr_ipv4; i++) {
-               struct sockaddr_in *a = &alist->addrs[i].srx.transport.sin;
-               u32 a_addr = ntohl(a->sin_addr.s_addr);
-               u16 a_port = ntohs(a->sin_port);
+       srx.srx_family = AF_RXRPC;
+       srx.transport_type = SOCK_DGRAM;
+       srx.transport_len = sizeof(srx.transport.sin);
+       srx.transport.sin.sin_family = AF_INET;
+       srx.transport.sin.sin_port = htons(port);
+       srx.transport.sin.sin_addr.s_addr = xdr;
 
-               if (addr == a_addr && port == a_port)
-                       return;
-               if (addr == a_addr && port < a_port)
-                       break;
-               if (addr < a_addr)
+       peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
+       if (!peer)
+               return -ENOMEM;
+
+       for (i = 0; i < alist->nr_ipv4; i++) {
+               if (peer == alist->addrs[i].peer) {
+                       rxrpc_kernel_put_peer(peer);
+                       return 0;
+               }
+               if (peer <= alist->addrs[i].peer)
                        break;
        }
 
@@ -298,38 +307,42 @@ void afs_merge_fs_addr4(struct afs_addr_list *alist, __be32 xdr, u16 port)
                        alist->addrs + i,
                        sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
 
-       srx = &alist->addrs[i].srx;
-       srx->srx_family = AF_RXRPC;
-       srx->transport_type = SOCK_DGRAM;
-       srx->transport_len = sizeof(srx->transport.sin);
-       srx->transport.sin.sin_family = AF_INET;
-       srx->transport.sin.sin_port = htons(port);
-       srx->transport.sin.sin_addr.s_addr = xdr;
+       alist->addrs[i].peer = peer;
        alist->nr_ipv4++;
        alist->nr_addrs++;
+       return 0;
 }
 
 /*
  * Merge an IPv6 entry into a fileserver address list.
  */
-void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port)
+int afs_merge_fs_addr6(struct afs_net *net, struct afs_addr_list *alist,
+                      __be32 *xdr, u16 port)
 {
-       struct sockaddr_rxrpc *srx;
-       int i, diff;
+       struct sockaddr_rxrpc srx;
+       struct rxrpc_peer *peer;
+       int i;
 
        if (alist->nr_addrs >= alist->max_addrs)
-               return;
+               return 0;
 
-       for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
-               struct sockaddr_in6 *a = &alist->addrs[i].srx.transport.sin6;
-               u16 a_port = ntohs(a->sin6_port);
+       srx.srx_family = AF_RXRPC;
+       srx.transport_type = SOCK_DGRAM;
+       srx.transport_len = sizeof(srx.transport.sin6);
+       srx.transport.sin6.sin6_family = AF_INET6;
+       srx.transport.sin6.sin6_port = htons(port);
+       memcpy(&srx.transport.sin6.sin6_addr, xdr, 16);
 
-               diff = memcmp(xdr, &a->sin6_addr, 16);
-               if (diff == 0 && port == a_port)
-                       return;
-               if (diff == 0 && port < a_port)
-                       break;
-               if (diff < 0)
+       peer = rxrpc_kernel_lookup_peer(net->socket, &srx, GFP_KERNEL);
+       if (!peer)
+               return -ENOMEM;
+
+       for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
+               if (peer == alist->addrs[i].peer) {
+                       rxrpc_kernel_put_peer(peer);
+                       return 0;
+               }
+               if (peer <= alist->addrs[i].peer)
                        break;
        }
 
@@ -337,15 +350,9 @@ void afs_merge_fs_addr6(struct afs_addr_list *alist, __be32 *xdr, u16 port)
                memmove(alist->addrs + i + 1,
                        alist->addrs + i,
                        sizeof(alist->addrs[0]) * (alist->nr_addrs - i));
-
-       srx = &alist->addrs[i].srx;
-       srx->srx_family = AF_RXRPC;
-       srx->transport_type = SOCK_DGRAM;
-       srx->transport_len = sizeof(srx->transport.sin6);
-       srx->transport.sin6.sin6_family = AF_INET6;
-       srx->transport.sin6.sin6_port = htons(port);
-       memcpy(&srx->transport.sin6.sin6_addr, xdr, 16);
+       alist->addrs[i].peer = peer;
        alist->nr_addrs++;
+       return 0;
 }
 
 /*
index d4ddb20..99a3f20 100644 (file)
@@ -146,10 +146,11 @@ static int afs_find_cm_server_by_peer(struct afs_call *call)
 {
        struct sockaddr_rxrpc srx;
        struct afs_server *server;
+       struct rxrpc_peer *peer;
 
-       rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx);
+       peer = rxrpc_kernel_get_call_peer(call->net->socket, call->rxcall);
 
-       server = afs_find_server(call->net, &srx);
+       server = afs_find_server(call->net, peer);
        if (!server) {
                trace_afs_cm_no_server(call, &srx);
                return 0;
index 3dd2484..58d28b8 100644 (file)
@@ -101,6 +101,7 @@ static void afs_fs_probe_not_done(struct afs_net *net,
 void afs_fileserver_probe_result(struct afs_call *call)
 {
        struct afs_addr_list *alist = call->alist;
+       struct afs_address *addr = &alist->addrs[call->addr_ix];
        struct afs_server *server = call->server;
        unsigned int index = call->addr_ix;
        unsigned int rtt_us = 0, cap0;
@@ -153,12 +154,12 @@ responded:
        if (call->service_id == YFS_FS_SERVICE) {
                server->probe.is_yfs = true;
                set_bit(AFS_SERVER_FL_IS_YFS, &server->flags);
-               alist->addrs[index].srx.srx_service = call->service_id;
+               addr->service_id = call->service_id;
        } else {
                server->probe.not_yfs = true;
                if (!server->probe.is_yfs) {
                        clear_bit(AFS_SERVER_FL_IS_YFS, &server->flags);
-                       alist->addrs[index].srx.srx_service = call->service_id;
+                       addr->service_id = call->service_id;
                }
                cap0 = ntohl(call->tmp);
                if (cap0 & AFS3_VICED_CAPABILITY_64BITFILES)
@@ -167,7 +168,7 @@ responded:
                        clear_bit(AFS_SERVER_FL_HAS_FS64, &server->flags);
        }
 
-       rxrpc_kernel_get_srtt(call->net->socket, call->rxcall, &rtt_us);
+       rtt_us = rxrpc_kernel_get_srtt(addr->peer);
        if (rtt_us < server->probe.rtt) {
                server->probe.rtt = rtt_us;
                server->rtt = rtt_us;
@@ -181,8 +182,8 @@ responded:
 out:
        spin_unlock(&server->probe_lock);
 
-       _debug("probe %pU [%u] %pISpc rtt=%u ret=%d",
-              &server->uuid, index, &alist->addrs[index].srx.transport,
+       _debug("probe %pU [%u] %pISpc rtt=%d ret=%d",
+              &server->uuid, index, rxrpc_kernel_remote_addr(alist->addrs[index].peer),
               rtt_us, ret);
 
        return afs_done_one_fs_probe(call->net, server);
index e2adb31..ec08b4a 100644 (file)
@@ -72,6 +72,11 @@ enum afs_call_state {
        AFS_CALL_COMPLETE,              /* Completed or failed */
 };
 
+struct afs_address {
+       struct rxrpc_peer       *peer;
+       u16                     service_id;
+};
+
 /*
  * List of server addresses.
  */
@@ -87,9 +92,7 @@ struct afs_addr_list {
        enum dns_lookup_status  status:8;
        unsigned long           failed;         /* Mask of addrs that failed locally/ICMP */
        unsigned long           responded;      /* Mask of addrs that responded */
-       struct {
-               struct sockaddr_rxrpc   srx;
-       } addrs[] __counted_by(max_addrs);
+       struct afs_address      addrs[] __counted_by(max_addrs);
 #define AFS_MAX_ADDRESSES ((unsigned int)(sizeof(unsigned long) * 8))
 };
 
@@ -420,7 +423,7 @@ struct afs_vlserver {
        atomic_t                probe_outstanding;
        spinlock_t              probe_lock;
        struct {
-               unsigned int    rtt;            /* RTT in uS */
+               unsigned int    rtt;            /* Best RTT in uS (or UINT_MAX) */
                u32             abort_code;
                short           error;
                unsigned short  flags;
@@ -537,7 +540,7 @@ struct afs_server {
        atomic_t                probe_outstanding;
        spinlock_t              probe_lock;
        struct {
-               unsigned int    rtt;            /* RTT in uS */
+               unsigned int    rtt;            /* Best RTT in uS (or UINT_MAX) */
                u32             abort_code;
                short           error;
                bool            responded:1;
@@ -964,9 +967,7 @@ static inline struct afs_addr_list *afs_get_addrlist(struct afs_addr_list *alist
                refcount_inc(&alist->usage);
        return alist;
 }
-extern struct afs_addr_list *afs_alloc_addrlist(unsigned int,
-                                               unsigned short,
-                                               unsigned short);
+extern struct afs_addr_list *afs_alloc_addrlist(unsigned int nr, u16 service_id);
 extern void afs_put_addrlist(struct afs_addr_list *);
 extern struct afs_vlserver_list *afs_parse_text_addrs(struct afs_net *,
                                                      const char *, size_t, char,
@@ -977,8 +978,10 @@ extern struct afs_vlserver_list *afs_dns_query(struct afs_cell *, time64_t *);
 extern bool afs_iterate_addresses(struct afs_addr_cursor *);
 extern int afs_end_cursor(struct afs_addr_cursor *);
 
-extern void afs_merge_fs_addr4(struct afs_addr_list *, __be32, u16);
-extern void afs_merge_fs_addr6(struct afs_addr_list *, __be32 *, u16);
+extern int afs_merge_fs_addr4(struct afs_net *net, struct afs_addr_list *addr,
+                             __be32 xdr, u16 port);
+extern int afs_merge_fs_addr6(struct afs_net *net, struct afs_addr_list *addr,
+                             __be32 *xdr, u16 port);
 
 /*
  * callback.c
@@ -1405,8 +1408,7 @@ extern void __exit afs_clean_up_permit_cache(void);
  */
 extern spinlock_t afs_server_peer_lock;
 
-extern struct afs_server *afs_find_server(struct afs_net *,
-                                         const struct sockaddr_rxrpc *);
+extern struct afs_server *afs_find_server(struct afs_net *, const struct rxrpc_peer *);
 extern struct afs_server *afs_find_server_by_uuid(struct afs_net *, const uuid_t *);
 extern struct afs_server *afs_lookup_server(struct afs_cell *, struct key *, const uuid_t *, u32);
 extern struct afs_server *afs_get_server(struct afs_server *, enum afs_server_trace);
index ab9cd98..8a65a06 100644 (file)
@@ -307,7 +307,7 @@ static int afs_proc_cell_vlservers_show(struct seq_file *m, void *v)
                for (i = 0; i < alist->nr_addrs; i++)
                        seq_printf(m, " %c %pISpc\n",
                                   alist->preferred == i ? '>' : '-',
-                                  &alist->addrs[i].srx.transport);
+                                  rxrpc_kernel_remote_addr(alist->addrs[i].peer));
        }
        seq_printf(m, " info: fl=%lx rtt=%d\n", vlserver->flags, vlserver->rtt);
        seq_printf(m, " probe: fl=%x e=%d ac=%d out=%d\n",
@@ -398,9 +398,10 @@ static int afs_proc_servers_show(struct seq_file *m, void *v)
        seq_printf(m, "  - ALIST v=%u rsp=%lx f=%lx\n",
                   alist->version, alist->responded, alist->failed);
        for (i = 0; i < alist->nr_addrs; i++)
-               seq_printf(m, "    [%x] %pISpc%s\n",
-                          i, &alist->addrs[i].srx.transport,
-                          alist->preferred == i ? "*" : "");
+               seq_printf(m, "    [%x] %pISpc%s rtt=%d\n",
+                          i, rxrpc_kernel_remote_addr(alist->addrs[i].peer),
+                          alist->preferred == i ? "*" : "",
+                          rxrpc_kernel_get_srtt(alist->addrs[i].peer));
        return 0;
 }
 
index 46081e5..59aed7a 100644 (file)
@@ -113,7 +113,7 @@ bool afs_select_fileserver(struct afs_operation *op)
        struct afs_server *server;
        struct afs_vnode *vnode = op->file[0].vnode;
        struct afs_error e;
-       u32 rtt;
+       unsigned int rtt;
        int error = op->ac.error, i;
 
        _enter("%lx[%d],%lx[%d],%d,%d",
@@ -420,7 +420,7 @@ pick_server:
        }
 
        op->index = -1;
-       rtt = U32_MAX;
+       rtt = UINT_MAX;
        for (i = 0; i < op->server_list->nr_servers; i++) {
                struct afs_server *s = op->server_list->servers[i].server;
 
@@ -488,7 +488,7 @@ iterate_address:
 
        _debug("address [%u] %u/%u %pISp",
               op->index, op->ac.index, op->ac.alist->nr_addrs,
-              &op->ac.alist->addrs[op->ac.index].srx.transport);
+              rxrpc_kernel_remote_addr(op->ac.alist->addrs[op->ac.index].peer));
 
        _leave(" = t");
        return true;
index 1813171..2603db0 100644 (file)
@@ -296,7 +296,8 @@ static void afs_notify_end_request_tx(struct sock *sock,
  */
 void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
 {
-       struct sockaddr_rxrpc *srx = &ac->alist->addrs[ac->index].srx;
+       struct afs_address *addr = &ac->alist->addrs[ac->index];
+       struct rxrpc_peer *peer = addr->peer;
        struct rxrpc_call *rxcall;
        struct msghdr msg;
        struct kvec iov[1];
@@ -304,7 +305,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
        s64 tx_total_len;
        int ret;
 
-       _enter(",{%pISp},", &srx->transport);
+       _enter(",{%pISp},", rxrpc_kernel_remote_addr(addr->peer));
 
        ASSERT(call->type != NULL);
        ASSERT(call->type->name != NULL);
@@ -333,7 +334,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
        }
 
        /* create a call */
-       rxcall = rxrpc_kernel_begin_call(call->net->socket, srx, call->key,
+       rxcall = rxrpc_kernel_begin_call(call->net->socket, peer, call->key,
                                         (unsigned long)call,
                                         tx_total_len,
                                         call->max_lifespan,
@@ -341,6 +342,7 @@ void afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, gfp_t gfp)
                                         (call->async ?
                                          afs_wake_up_async_call :
                                          afs_wake_up_call_waiter),
+                                        addr->service_id,
                                         call->upgrade,
                                         (call->intr ? RXRPC_PREINTERRUPTIBLE :
                                          RXRPC_UNINTERRUPTIBLE),
@@ -461,7 +463,7 @@ static void afs_log_error(struct afs_call *call, s32 remote_abort)
                max = m + 1;
                pr_notice("kAFS: Peer reported %s failure on %s [%pISp]\n",
                          msg, call->type->name,
-                         &call->alist->addrs[call->addr_ix].srx.transport);
+                         rxrpc_kernel_remote_addr(call->alist->addrs[call->addr_ix].peer));
        }
 }
 
index b8e2d21..5b5fa94 100644 (file)
@@ -21,13 +21,12 @@ static void __afs_put_server(struct afs_net *, struct afs_server *);
 /*
  * Find a server by one of its addresses.
  */
-struct afs_server *afs_find_server(struct afs_net *net,
-                                  const struct sockaddr_rxrpc *srx)
+struct afs_server *afs_find_server(struct afs_net *net, const struct rxrpc_peer *peer)
 {
        const struct afs_addr_list *alist;
        struct afs_server *server = NULL;
        unsigned int i;
-       int seq = 1, diff;
+       int seq = 1;
 
        rcu_read_lock();
 
@@ -38,37 +37,11 @@ struct afs_server *afs_find_server(struct afs_net *net,
                seq++; /* 2 on the 1st/lockless path, otherwise odd */
                read_seqbegin_or_lock(&net->fs_addr_lock, &seq);
 
-               if (srx->transport.family == AF_INET6) {
-                       const struct sockaddr_in6 *a = &srx->transport.sin6, *b;
-                       hlist_for_each_entry_rcu(server, &net->fs_addresses6, addr6_link) {
-                               alist = rcu_dereference(server->addresses);
-                               for (i = alist->nr_ipv4; i < alist->nr_addrs; i++) {
-                                       b = &alist->addrs[i].srx.transport.sin6;
-                                       diff = ((u16 __force)a->sin6_port -
-                                               (u16 __force)b->sin6_port);
-                                       if (diff == 0)
-                                               diff = memcmp(&a->sin6_addr,
-                                                             &b->sin6_addr,
-                                                             sizeof(struct in6_addr));
-                                       if (diff == 0)
-                                               goto found;
-                               }
-                       }
-               } else {
-                       const struct sockaddr_in *a = &srx->transport.sin, *b;
-                       hlist_for_each_entry_rcu(server, &net->fs_addresses4, addr4_link) {
-                               alist = rcu_dereference(server->addresses);
-                               for (i = 0; i < alist->nr_ipv4; i++) {
-                                       b = &alist->addrs[i].srx.transport.sin;
-                                       diff = ((u16 __force)a->sin_port -
-                                               (u16 __force)b->sin_port);
-                                       if (diff == 0)
-                                               diff = ((u32 __force)a->sin_addr.s_addr -
-                                                       (u32 __force)b->sin_addr.s_addr);
-                                       if (diff == 0)
-                                               goto found;
-                               }
-                       }
+               hlist_for_each_entry_rcu(server, &net->fs_addresses6, addr6_link) {
+                       alist = rcu_dereference(server->addresses);
+                       for (i = 0; i < alist->nr_addrs; i++)
+                               if (alist->addrs[i].peer == peer)
+                                       goto found;
                }
 
                server = NULL;
index d3c0df7..6fdf9f1 100644 (file)
@@ -32,55 +32,6 @@ static struct afs_volume *afs_sample_volume(struct afs_cell *cell, struct key *k
        return volume;
 }
 
-/*
- * Compare two addresses.
- */
-static int afs_compare_addrs(const struct sockaddr_rxrpc *srx_a,
-                            const struct sockaddr_rxrpc *srx_b)
-{
-       short port_a, port_b;
-       int addr_a, addr_b, diff;
-
-       diff = (short)srx_a->transport_type - (short)srx_b->transport_type;
-       if (diff)
-               goto out;
-
-       switch (srx_a->transport_type) {
-       case AF_INET: {
-               const struct sockaddr_in *a = &srx_a->transport.sin;
-               const struct sockaddr_in *b = &srx_b->transport.sin;
-               addr_a = ntohl(a->sin_addr.s_addr);
-               addr_b = ntohl(b->sin_addr.s_addr);
-               diff = addr_a - addr_b;
-               if (diff == 0) {
-                       port_a = ntohs(a->sin_port);
-                       port_b = ntohs(b->sin_port);
-                       diff = port_a - port_b;
-               }
-               break;
-       }
-
-       case AF_INET6: {
-               const struct sockaddr_in6 *a = &srx_a->transport.sin6;
-               const struct sockaddr_in6 *b = &srx_b->transport.sin6;
-               diff = memcmp(&a->sin6_addr, &b->sin6_addr, 16);
-               if (diff == 0) {
-                       port_a = ntohs(a->sin6_port);
-                       port_b = ntohs(b->sin6_port);
-                       diff = port_a - port_b;
-               }
-               break;
-       }
-
-       default:
-               WARN_ON(1);
-               diff = 1;
-       }
-
-out:
-       return diff;
-}
-
 /*
  * Compare the address lists of a pair of fileservers.
  */
@@ -94,9 +45,9 @@ static int afs_compare_fs_alists(const struct afs_server *server_a,
        lb = rcu_dereference(server_b->addresses);
 
        while (a < la->nr_addrs && b < lb->nr_addrs) {
-               const struct sockaddr_rxrpc *srx_a = &la->addrs[a].srx;
-               const struct sockaddr_rxrpc *srx_b = &lb->addrs[b].srx;
-               int diff = afs_compare_addrs(srx_a, srx_b);
+               unsigned long pa = (unsigned long)la->addrs[a].peer;
+               unsigned long pb = (unsigned long)lb->addrs[b].peer;
+               long diff = pa - pb;
 
                if (diff < 0) {
                        a++;
index acc4821..ba89140 100644 (file)
@@ -83,14 +83,15 @@ static u16 afs_extract_le16(const u8 **_b)
 /*
  * Build a VL server address list from a DNS queried server list.
  */
-static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
+static struct afs_addr_list *afs_extract_vl_addrs(struct afs_net *net,
+                                                 const u8 **_b, const u8 *end,
                                                  u8 nr_addrs, u16 port)
 {
        struct afs_addr_list *alist;
        const u8 *b = *_b;
        int ret = -EINVAL;
 
-       alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE, port);
+       alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE);
        if (!alist)
                return ERR_PTR(-ENOMEM);
        if (nr_addrs == 0)
@@ -109,7 +110,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
                                goto error;
                        }
                        memcpy(x, b, 4);
-                       afs_merge_fs_addr4(alist, x[0], port);
+                       ret = afs_merge_fs_addr4(net, alist, x[0], port);
+                       if (ret < 0)
+                               goto error;
                        b += 4;
                        break;
 
@@ -119,7 +122,9 @@ static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
                                goto error;
                        }
                        memcpy(x, b, 16);
-                       afs_merge_fs_addr6(alist, x, port);
+                       ret = afs_merge_fs_addr6(net, alist, x, port);
+                       if (ret < 0)
+                               goto error;
                        b += 16;
                        break;
 
@@ -247,7 +252,7 @@ struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell,
                /* Extract the addresses - note that we can't skip this as we
                 * have to advance the payload pointer.
                 */
-               addrs = afs_extract_vl_addrs(&b, end, bs.nr_addrs, bs.port);
+               addrs = afs_extract_vl_addrs(cell->net, &b, end, bs.nr_addrs, bs.port);
                if (IS_ERR(addrs)) {
                        ret = PTR_ERR(addrs);
                        goto error_2;
index bdd9372..9551aef 100644 (file)
@@ -48,6 +48,7 @@ void afs_vlserver_probe_result(struct afs_call *call)
 {
        struct afs_addr_list *alist = call->alist;
        struct afs_vlserver *server = call->vlserver;
+       struct afs_address *addr = &alist->addrs[call->addr_ix];
        unsigned int server_index = call->server_index;
        unsigned int rtt_us = 0;
        unsigned int index = call->addr_ix;
@@ -106,16 +107,16 @@ responded:
        if (call->service_id == YFS_VL_SERVICE) {
                server->probe.flags |= AFS_VLSERVER_PROBE_IS_YFS;
                set_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags);
-               alist->addrs[index].srx.srx_service = call->service_id;
+               addr->service_id = call->service_id;
        } else {
                server->probe.flags |= AFS_VLSERVER_PROBE_NOT_YFS;
                if (!(server->probe.flags & AFS_VLSERVER_PROBE_IS_YFS)) {
                        clear_bit(AFS_VLSERVER_FL_IS_YFS, &server->flags);
-                       alist->addrs[index].srx.srx_service = call->service_id;
+                       addr->service_id = call->service_id;
                }
        }
 
-       rxrpc_kernel_get_srtt(call->net->socket, call->rxcall, &rtt_us);
+       rtt_us = rxrpc_kernel_get_srtt(addr->peer);
        if (rtt_us < server->probe.rtt) {
                server->probe.rtt = rtt_us;
                server->rtt = rtt_us;
@@ -130,8 +131,9 @@ responded:
 out:
        spin_unlock(&server->probe_lock);
 
-       _debug("probe [%u][%u] %pISpc rtt=%u ret=%d",
-              server_index, index, &alist->addrs[index].srx.transport, rtt_us, ret);
+       _debug("probe [%u][%u] %pISpc rtt=%d ret=%d",
+              server_index, index, rxrpc_kernel_remote_addr(addr->peer),
+              rtt_us, ret);
 
        afs_done_one_vl_probe(server, have_result);
 }
index e52b9d4..f8f255c 100644 (file)
@@ -92,7 +92,7 @@ bool afs_select_vlserver(struct afs_vl_cursor *vc)
        struct afs_addr_list *alist;
        struct afs_vlserver *vlserver;
        struct afs_error e;
-       u32 rtt;
+       unsigned int rtt;
        int error = vc->ac.error, i;
 
        _enter("%lx[%d],%lx[%d],%d,%d",
@@ -194,7 +194,7 @@ pick_server:
                goto selected_server;
 
        vc->index = -1;
-       rtt = U32_MAX;
+       rtt = UINT_MAX;
        for (i = 0; i < vc->server_list->nr_servers; i++) {
                struct afs_vlserver *s = vc->server_list->servers[i].server;
 
@@ -249,7 +249,7 @@ iterate_address:
 
        _debug("VL address %d/%d", vc->ac.index, vc->ac.alist->nr_addrs);
 
-       _leave(" = t %pISpc", &vc->ac.alist->addrs[vc->ac.index].srx.transport);
+       _leave(" = t %pISpc", rxrpc_kernel_remote_addr(vc->ac.alist->addrs[vc->ac.index].peer));
        return true;
 
 next_server:
index 00fca3c..41e7932 100644 (file)
@@ -208,7 +208,7 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call)
                count           = ntohl(*bp);
 
                nentries = min(nentries, count);
-               alist = afs_alloc_addrlist(nentries, FS_SERVICE, AFS_FS_PORT);
+               alist = afs_alloc_addrlist(nentries, FS_SERVICE);
                if (!alist)
                        return -ENOMEM;
                alist->version = uniquifier;
@@ -230,9 +230,13 @@ static int afs_deliver_vl_get_addrs_u(struct afs_call *call)
                alist = call->ret_alist;
                bp = call->buffer;
                count = min(call->count, 4U);
-               for (i = 0; i < count; i++)
-                       if (alist->nr_addrs < call->count2)
-                               afs_merge_fs_addr4(alist, *bp++, AFS_FS_PORT);
+               for (i = 0; i < count; i++) {
+                       if (alist->nr_addrs < call->count2) {
+                               ret = afs_merge_fs_addr4(call->net, alist, *bp++, AFS_FS_PORT);
+                               if (ret < 0)
+                                       return ret;
+                       }
+               }
 
                call->count -= count;
                if (call->count > 0)
@@ -450,7 +454,7 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call)
                if (call->count > YFS_MAXENDPOINTS)
                        return afs_protocol_error(call, afs_eproto_yvl_fsendpt_num);
 
-               alist = afs_alloc_addrlist(call->count, FS_SERVICE, AFS_FS_PORT);
+               alist = afs_alloc_addrlist(call->count, FS_SERVICE);
                if (!alist)
                        return -ENOMEM;
                alist->version = uniquifier;
@@ -488,14 +492,18 @@ static int afs_deliver_yfsvl_get_endpoints(struct afs_call *call)
                        if (ntohl(bp[0]) != sizeof(__be32) * 2)
                                return afs_protocol_error(
                                        call, afs_eproto_yvl_fsendpt4_len);
-                       afs_merge_fs_addr4(alist, bp[1], ntohl(bp[2]));
+                       ret = afs_merge_fs_addr4(call->net, alist, bp[1], ntohl(bp[2]));
+                       if (ret < 0)
+                               return ret;
                        bp += 3;
                        break;
                case YFS_ENDPOINT_IPV6:
                        if (ntohl(bp[0]) != sizeof(__be32) * 5)
                                return afs_protocol_error(
                                        call, afs_eproto_yvl_fsendpt6_len);
-                       afs_merge_fs_addr6(alist, bp + 1, ntohl(bp[5]));
+                       ret = afs_merge_fs_addr6(call->net, alist, bp + 1, ntohl(bp[5]));
+                       if (ret < 0)
+                               return ret;
                        bp += 6;
                        break;
                default:
index 5531dd0..0754c46 100644 (file)
@@ -15,6 +15,7 @@ struct key;
 struct sock;
 struct socket;
 struct rxrpc_call;
+struct rxrpc_peer;
 enum rxrpc_abort_reason;
 
 enum rxrpc_interruptibility {
@@ -41,13 +42,14 @@ void rxrpc_kernel_new_call_notification(struct socket *,
                                        rxrpc_notify_new_call_t,
                                        rxrpc_discard_new_call_t);
 struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
-                                          struct sockaddr_rxrpc *srx,
+                                          struct rxrpc_peer *peer,
                                           struct key *key,
                                           unsigned long user_call_ID,
                                           s64 tx_total_len,
                                           u32 hard_timeout,
                                           gfp_t gfp,
                                           rxrpc_notify_rx_t notify_rx,
+                                          u16 service_id,
                                           bool upgrade,
                                           enum rxrpc_interruptibility interruptibility,
                                           unsigned int debug_id);
@@ -60,9 +62,14 @@ bool rxrpc_kernel_abort_call(struct socket *, struct rxrpc_call *,
                             u32, int, enum rxrpc_abort_reason);
 void rxrpc_kernel_shutdown_call(struct socket *sock, struct rxrpc_call *call);
 void rxrpc_kernel_put_call(struct socket *sock, struct rxrpc_call *call);
-void rxrpc_kernel_get_peer(struct socket *, struct rxrpc_call *,
-                          struct sockaddr_rxrpc *);
-bool rxrpc_kernel_get_srtt(struct socket *, struct rxrpc_call *, u32 *);
+struct rxrpc_peer *rxrpc_kernel_lookup_peer(struct socket *sock,
+                                           struct sockaddr_rxrpc *srx, gfp_t gfp);
+void rxrpc_kernel_put_peer(struct rxrpc_peer *peer);
+struct rxrpc_peer *rxrpc_kernel_get_peer(struct rxrpc_peer *peer);
+struct rxrpc_peer *rxrpc_kernel_get_call_peer(struct socket *sock, struct rxrpc_call *call);
+const struct sockaddr_rxrpc *rxrpc_kernel_remote_srx(const struct rxrpc_peer *peer);
+const struct sockaddr *rxrpc_kernel_remote_addr(const struct rxrpc_peer *peer);
+unsigned int rxrpc_kernel_get_srtt(const struct rxrpc_peer *);
 int rxrpc_kernel_charge_accept(struct socket *, rxrpc_notify_rx_t,
                               rxrpc_user_attach_call_t, unsigned long, gfp_t,
                               unsigned int);
index f7e537f..4c1ef7b 100644 (file)
 #define rxrpc_peer_traces \
        EM(rxrpc_peer_free,                     "FREE        ") \
        EM(rxrpc_peer_get_accept,               "GET accept  ") \
+       EM(rxrpc_peer_get_application,          "GET app     ") \
        EM(rxrpc_peer_get_bundle,               "GET bundle  ") \
+       EM(rxrpc_peer_get_call,                 "GET call    ") \
        EM(rxrpc_peer_get_client_conn,          "GET cln-conn") \
        EM(rxrpc_peer_get_input,                "GET input   ") \
        EM(rxrpc_peer_get_input_error,          "GET inpt-err") \
        EM(rxrpc_peer_get_service_conn,         "GET srv-conn") \
        EM(rxrpc_peer_new_client,               "NEW client  ") \
        EM(rxrpc_peer_new_prealloc,             "NEW prealloc") \
+       EM(rxrpc_peer_put_application,          "PUT app     ") \
        EM(rxrpc_peer_put_bundle,               "PUT bundle  ") \
        EM(rxrpc_peer_put_call,                 "PUT call    ") \
        EM(rxrpc_peer_put_conn,                 "PUT conn    ") \
index fa8aec7..465bfe5 100644 (file)
@@ -258,16 +258,62 @@ static int rxrpc_listen(struct socket *sock, int backlog)
        return ret;
 }
 
+/**
+ * rxrpc_kernel_lookup_peer - Obtain remote transport endpoint for an address
+ * @sock: The socket through which it will be accessed
+ * @srx: The network address
+ * @gfp: Allocation flags
+ *
+ * Lookup or create a remote transport endpoint record for the specified
+ * address and return it with a ref held.
+ */
+struct rxrpc_peer *rxrpc_kernel_lookup_peer(struct socket *sock,
+                                           struct sockaddr_rxrpc *srx, gfp_t gfp)
+{
+       struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
+       int ret;
+
+       ret = rxrpc_validate_address(rx, srx, sizeof(*srx));
+       if (ret < 0)
+               return ERR_PTR(ret);
+
+       return rxrpc_lookup_peer(rx->local, srx, gfp);
+}
+EXPORT_SYMBOL(rxrpc_kernel_lookup_peer);
+
+/**
+ * rxrpc_kernel_get_peer - Get a reference on a peer
+ * @peer: The peer to get a reference on.
+ *
+ * Get a record for the remote peer in a call.
+ */
+struct rxrpc_peer *rxrpc_kernel_get_peer(struct rxrpc_peer *peer)
+{
+       return peer ? rxrpc_get_peer(peer, rxrpc_peer_get_application) : NULL;
+}
+EXPORT_SYMBOL(rxrpc_kernel_get_peer);
+
+/**
+ * rxrpc_kernel_put_peer - Allow a kernel app to drop a peer reference
+ * @peer: The peer to drop a ref on
+ */
+void rxrpc_kernel_put_peer(struct rxrpc_peer *peer)
+{
+       rxrpc_put_peer(peer, rxrpc_peer_put_application);
+}
+EXPORT_SYMBOL(rxrpc_kernel_put_peer);
+
 /**
  * rxrpc_kernel_begin_call - Allow a kernel service to begin a call
  * @sock: The socket on which to make the call
- * @srx: The address of the peer to contact
+ * @peer: The peer to contact
  * @key: The security context to use (defaults to socket setting)
  * @user_call_ID: The ID to use
  * @tx_total_len: Total length of data to transmit during the call (or -1)
  * @hard_timeout: The maximum lifespan of the call in sec
  * @gfp: The allocation constraints
  * @notify_rx: Where to send notifications instead of socket queue
+ * @service_id: The ID of the service to contact
  * @upgrade: Request service upgrade for call
  * @interruptibility: The call is interruptible, or can be canceled.
  * @debug_id: The debug ID for tracing to be assigned to the call
@@ -280,13 +326,14 @@ static int rxrpc_listen(struct socket *sock, int backlog)
  * supplying @srx and @key.
  */
 struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
-                                          struct sockaddr_rxrpc *srx,
+                                          struct rxrpc_peer *peer,
                                           struct key *key,
                                           unsigned long user_call_ID,
                                           s64 tx_total_len,
                                           u32 hard_timeout,
                                           gfp_t gfp,
                                           rxrpc_notify_rx_t notify_rx,
+                                          u16 service_id,
                                           bool upgrade,
                                           enum rxrpc_interruptibility interruptibility,
                                           unsigned int debug_id)
@@ -295,13 +342,11 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
        struct rxrpc_call_params p;
        struct rxrpc_call *call;
        struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
-       int ret;
 
        _enter(",,%x,%lx", key_serial(key), user_call_ID);
 
-       ret = rxrpc_validate_address(rx, srx, sizeof(*srx));
-       if (ret < 0)
-               return ERR_PTR(ret);
+       if (WARN_ON_ONCE(peer->local != rx->local))
+               return ERR_PTR(-EIO);
 
        lock_sock(&rx->sk);
 
@@ -319,12 +364,13 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
 
        memset(&cp, 0, sizeof(cp));
        cp.local                = rx->local;
+       cp.peer                 = peer;
        cp.key                  = key;
        cp.security_level       = rx->min_sec_level;
        cp.exclusive            = false;
        cp.upgrade              = upgrade;
-       cp.service_id           = srx->srx_service;
-       call = rxrpc_new_client_call(rx, &cp, srx, &p, gfp, debug_id);
+       cp.service_id           = service_id;
+       call = rxrpc_new_client_call(rx, &cp, &p, gfp, debug_id);
        /* The socket has been unlocked. */
        if (!IS_ERR(call)) {
                call->notify_rx = notify_rx;
index e8e14c6..8eea7a4 100644 (file)
@@ -364,6 +364,7 @@ struct rxrpc_conn_proto {
 
 struct rxrpc_conn_parameters {
        struct rxrpc_local      *local;         /* Representation of local endpoint */
+       struct rxrpc_peer       *peer;          /* Representation of remote endpoint */
        struct key              *key;           /* Security details */
        bool                    exclusive;      /* T if conn is exclusive */
        bool                    upgrade;        /* T if service ID can be upgraded */
@@ -867,7 +868,6 @@ struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *, unsigned long
 struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *, gfp_t, unsigned int);
 struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *,
                                         struct rxrpc_conn_parameters *,
-                                        struct sockaddr_rxrpc *,
                                         struct rxrpc_call_params *, gfp_t,
                                         unsigned int);
 void rxrpc_start_call_timer(struct rxrpc_call *call);
index 773eecd..beea25a 100644 (file)
@@ -193,7 +193,6 @@ struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp,
  * Allocate a new client call.
  */
 static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx,
-                                                 struct sockaddr_rxrpc *srx,
                                                  struct rxrpc_conn_parameters *cp,
                                                  struct rxrpc_call_params *p,
                                                  gfp_t gfp,
@@ -211,10 +210,12 @@ static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx,
        now = ktime_get_real();
        call->acks_latest_ts    = now;
        call->cong_tstamp       = now;
-       call->dest_srx          = *srx;
+       call->dest_srx          = cp->peer->srx;
+       call->dest_srx.srx_service = cp->service_id;
        call->interruptibility  = p->interruptibility;
        call->tx_total_len      = p->tx_total_len;
        call->key               = key_get(cp->key);
+       call->peer              = rxrpc_get_peer(cp->peer, rxrpc_peer_get_call);
        call->local             = rxrpc_get_local(cp->local, rxrpc_local_get_call);
        call->security_level    = cp->security_level;
        if (p->kernel)
@@ -306,10 +307,6 @@ static int rxrpc_connect_call(struct rxrpc_call *call, gfp_t gfp)
 
        _enter("{%d,%lx},", call->debug_id, call->user_call_ID);
 
-       call->peer = rxrpc_lookup_peer(local, &call->dest_srx, gfp);
-       if (!call->peer)
-               goto error;
-
        ret = rxrpc_look_up_bundle(call, gfp);
        if (ret < 0)
                goto error;
@@ -334,7 +331,6 @@ error:
  */
 struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx,
                                         struct rxrpc_conn_parameters *cp,
-                                        struct sockaddr_rxrpc *srx,
                                         struct rxrpc_call_params *p,
                                         gfp_t gfp,
                                         unsigned int debug_id)
@@ -349,13 +345,18 @@ struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx,
 
        _enter("%p,%lx", rx, p->user_call_ID);
 
+       if (WARN_ON_ONCE(!cp->peer)) {
+               release_sock(&rx->sk);
+               return ERR_PTR(-EIO);
+       }
+
        limiter = rxrpc_get_call_slot(p, gfp);
        if (!limiter) {
                release_sock(&rx->sk);
                return ERR_PTR(-ERESTARTSYS);
        }
 
-       call = rxrpc_alloc_client_call(rx, srx, cp, p, gfp, debug_id);
+       call = rxrpc_alloc_client_call(rx, cp, p, gfp, debug_id);
        if (IS_ERR(call)) {
                release_sock(&rx->sk);
                up(limiter);
index 8d7a715..49dcda6 100644 (file)
@@ -22,6 +22,8 @@
 #include <net/ip6_route.h>
 #include "ar-internal.h"
 
+static const struct sockaddr_rxrpc rxrpc_null_addr;
+
 /*
  * Hash a peer key.
  */
@@ -457,39 +459,53 @@ void rxrpc_destroy_all_peers(struct rxrpc_net *rxnet)
 }
 
 /**
- * rxrpc_kernel_get_peer - Get the peer address of a call
+ * rxrpc_kernel_get_call_peer - Get the peer address of a call
  * @sock: The socket on which the call is in progress.
  * @call: The call to query
- * @_srx: Where to place the result
  *
- * Get the address of the remote peer in a call.
+ * Get a record for the remote peer in a call.
  */
-void rxrpc_kernel_get_peer(struct socket *sock, struct rxrpc_call *call,
-                          struct sockaddr_rxrpc *_srx)
+struct rxrpc_peer *rxrpc_kernel_get_call_peer(struct socket *sock, struct rxrpc_call *call)
 {
-       *_srx = call->peer->srx;
+       return call->peer;
 }
-EXPORT_SYMBOL(rxrpc_kernel_get_peer);
+EXPORT_SYMBOL(rxrpc_kernel_get_call_peer);
 
 /**
  * rxrpc_kernel_get_srtt - Get a call's peer smoothed RTT
- * @sock: The socket on which the call is in progress.
- * @call: The call to query
- * @_srtt: Where to store the SRTT value.
+ * @peer: The peer to query
  *
- * Get the call's peer smoothed RTT in uS.
+ * Get the call's peer smoothed RTT in uS or UINT_MAX if we have no samples.
  */
-bool rxrpc_kernel_get_srtt(struct socket *sock, struct rxrpc_call *call,
-                          u32 *_srtt)
+unsigned int rxrpc_kernel_get_srtt(const struct rxrpc_peer *peer)
 {
-       struct rxrpc_peer *peer = call->peer;
+       return peer->rtt_count > 0 ? peer->srtt_us >> 3 : UINT_MAX;
+}
+EXPORT_SYMBOL(rxrpc_kernel_get_srtt);
 
-       if (peer->rtt_count == 0) {
-               *_srtt = 1000000; /* 1S */
-               return false;
-       }
+/**
+ * rxrpc_kernel_remote_srx - Get the address of a peer
+ * @peer: The peer to query
+ *
+ * Get a pointer to the address from a peer record.  The caller is responsible
+ * for making sure that the address is not deallocated.
+ */
+const struct sockaddr_rxrpc *rxrpc_kernel_remote_srx(const struct rxrpc_peer *peer)
+{
+       return peer ? &peer->srx : &rxrpc_null_addr;
+}
+EXPORT_SYMBOL(rxrpc_kernel_remote_srx);
 
-       *_srtt = call->peer->srtt_us >> 3;
-       return true;
+/**
+ * rxrpc_kernel_remote_addr - Get the peer transport address of a call
+ * @peer: The peer to query
+ *
+ * Get a pointer to the transport address from a peer record.  The caller is
+ * responsible for making sure that the address is not deallocated.
+ */
+const struct sockaddr *rxrpc_kernel_remote_addr(const struct rxrpc_peer *peer)
+{
+       return (const struct sockaddr *)
+               (peer ? &peer->srx.transport : &rxrpc_null_addr.transport);
 }
-EXPORT_SYMBOL(rxrpc_kernel_get_srtt);
+EXPORT_SYMBOL(rxrpc_kernel_remote_addr);
index 8e0b947..5677d56 100644 (file)
@@ -572,6 +572,7 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg,
        __acquires(&call->user_mutex)
 {
        struct rxrpc_conn_parameters cp;
+       struct rxrpc_peer *peer;
        struct rxrpc_call *call;
        struct key *key;
 
@@ -584,21 +585,29 @@ rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg,
                return ERR_PTR(-EDESTADDRREQ);
        }
 
+       peer = rxrpc_lookup_peer(rx->local, srx, GFP_KERNEL);
+       if (!peer) {
+               release_sock(&rx->sk);
+               return ERR_PTR(-ENOMEM);
+       }
+
        key = rx->key;
        if (key && !rx->key->payload.data[0])
                key = NULL;
 
        memset(&cp, 0, sizeof(cp));
        cp.local                = rx->local;
+       cp.peer                 = peer;
        cp.key                  = rx->key;
        cp.security_level       = rx->min_sec_level;
        cp.exclusive            = rx->exclusive | p->exclusive;
        cp.upgrade              = p->upgrade;
        cp.service_id           = srx->srx_service;
-       call = rxrpc_new_client_call(rx, &cp, srx, &p->call, GFP_KERNEL,
+       call = rxrpc_new_client_call(rx, &cp, &p->call, GFP_KERNEL,
                                     atomic_inc_return(&rxrpc_debug_id));
        /* The socket is now unlocked */
 
+       rxrpc_put_peer(peer, rxrpc_peer_put_application);
        _leave(" = %p\n", call);
        return call;
 }