Merge tag 'block-5.15-2021-09-05' of git://git.kernel.dk/linux-block
[linux-2.6-microblaze.git] / kernel / tracepoint.c
index 976bf8c..64ea283 100644 (file)
 #include <linux/sched/task.h>
 #include <linux/static_key.h>
 
+enum tp_func_state {
+       TP_FUNC_0,
+       TP_FUNC_1,
+       TP_FUNC_2,
+       TP_FUNC_N,
+};
+
 extern tracepoint_ptr_t __start___tracepoints_ptrs[];
 extern tracepoint_ptr_t __stop___tracepoints_ptrs[];
 
 DEFINE_SRCU(tracepoint_srcu);
 EXPORT_SYMBOL_GPL(tracepoint_srcu);
 
+enum tp_transition_sync {
+       TP_TRANSITION_SYNC_1_0_1,
+       TP_TRANSITION_SYNC_N_2_1,
+
+       _NR_TP_TRANSITION_SYNC,
+};
+
+struct tp_transition_snapshot {
+       unsigned long rcu;
+       unsigned long srcu;
+       bool ongoing;
+};
+
+/* Protected by tracepoints_mutex */
+static struct tp_transition_snapshot tp_transition_snapshot[_NR_TP_TRANSITION_SYNC];
+
+static void tp_rcu_get_state(enum tp_transition_sync sync)
+{
+       struct tp_transition_snapshot *snapshot = &tp_transition_snapshot[sync];
+
+       /* Keep the latest get_state snapshot. */
+       snapshot->rcu = get_state_synchronize_rcu();
+       snapshot->srcu = start_poll_synchronize_srcu(&tracepoint_srcu);
+       snapshot->ongoing = true;
+}
+
+static void tp_rcu_cond_sync(enum tp_transition_sync sync)
+{
+       struct tp_transition_snapshot *snapshot = &tp_transition_snapshot[sync];
+
+       if (!snapshot->ongoing)
+               return;
+       cond_synchronize_rcu(snapshot->rcu);
+       if (!poll_state_synchronize_srcu(&tracepoint_srcu, snapshot->srcu))
+               synchronize_srcu(&tracepoint_srcu);
+       snapshot->ongoing = false;
+}
+
 /* Set to 1 to enable tracepoint debug output */
 static const int tracepoint_debug;
 
@@ -246,26 +291,29 @@ static void *func_remove(struct tracepoint_func **funcs,
        return old;
 }
 
-static void tracepoint_update_call(struct tracepoint *tp, struct tracepoint_func *tp_funcs, bool sync)
+/*
+ * Count the number of functions (enum tp_func_state) in a tp_funcs array.
+ */
+static enum tp_func_state nr_func_state(const struct tracepoint_func *tp_funcs)
+{
+       if (!tp_funcs)
+               return TP_FUNC_0;
+       if (!tp_funcs[1].func)
+               return TP_FUNC_1;
+       if (!tp_funcs[2].func)
+               return TP_FUNC_2;
+       return TP_FUNC_N;       /* 3 or more */
+}
+
+static void tracepoint_update_call(struct tracepoint *tp, struct tracepoint_func *tp_funcs)
 {
        void *func = tp->iterator;
 
        /* Synthetic events do not have static call sites */
        if (!tp->static_call_key)
                return;
-
-       if (!tp_funcs[1].func) {
+       if (nr_func_state(tp_funcs) == TP_FUNC_1)
                func = tp_funcs[0].func;
-               /*
-                * If going from the iterator back to a single caller,
-                * we need to synchronize with __DO_TRACE to make sure
-                * that the data passed to the callback is the one that
-                * belongs to that callback.
-                */
-               if (sync)
-                       tracepoint_synchronize_unregister();
-       }
-
        __static_call_update(tp->static_call_key, tp->static_call_tramp, func);
 }
 
@@ -299,9 +347,41 @@ static int tracepoint_add_func(struct tracepoint *tp,
         * a pointer to it.  This array is referenced by __DO_TRACE from
         * include/linux/tracepoint.h using rcu_dereference_sched().
         */
-       rcu_assign_pointer(tp->funcs, tp_funcs);
-       tracepoint_update_call(tp, tp_funcs, false);
-       static_key_enable(&tp->key);
+       switch (nr_func_state(tp_funcs)) {
+       case TP_FUNC_1:         /* 0->1 */
+               /*
+                * Make sure new static func never uses old data after a
+                * 1->0->1 transition sequence.
+                */
+               tp_rcu_cond_sync(TP_TRANSITION_SYNC_1_0_1);
+               /* Set static call to first function */
+               tracepoint_update_call(tp, tp_funcs);
+               /* Both iterator and static call handle NULL tp->funcs */
+               rcu_assign_pointer(tp->funcs, tp_funcs);
+               static_key_enable(&tp->key);
+               break;
+       case TP_FUNC_2:         /* 1->2 */
+               /* Set iterator static call */
+               tracepoint_update_call(tp, tp_funcs);
+               /*
+                * Iterator callback installed before updating tp->funcs.
+                * Requires ordering between RCU assign/dereference and
+                * static call update/call.
+                */
+               fallthrough;
+       case TP_FUNC_N:         /* N->N+1 (N>1) */
+               rcu_assign_pointer(tp->funcs, tp_funcs);
+               /*
+                * Make sure static func never uses incorrect data after a
+                * N->...->2->1 (N>1) transition sequence.
+                */
+               if (tp_funcs[0].data != old[0].data)
+                       tp_rcu_get_state(TP_TRANSITION_SYNC_N_2_1);
+               break;
+       default:
+               WARN_ON_ONCE(1);
+               break;
+       }
 
        release_probes(old);
        return 0;
@@ -328,17 +408,52 @@ static int tracepoint_remove_func(struct tracepoint *tp,
                /* Failed allocating new tp_funcs, replaced func with stub */
                return 0;
 
-       if (!tp_funcs) {
+       switch (nr_func_state(tp_funcs)) {
+       case TP_FUNC_0:         /* 1->0 */
                /* Removed last function */
                if (tp->unregfunc && static_key_enabled(&tp->key))
                        tp->unregfunc();
 
                static_key_disable(&tp->key);
+               /* Set iterator static call */
+               tracepoint_update_call(tp, tp_funcs);
+               /* Both iterator and static call handle NULL tp->funcs */
+               rcu_assign_pointer(tp->funcs, NULL);
+               /*
+                * Make sure new static func never uses old data after a
+                * 1->0->1 transition sequence.
+                */
+               tp_rcu_get_state(TP_TRANSITION_SYNC_1_0_1);
+               break;
+       case TP_FUNC_1:         /* 2->1 */
                rcu_assign_pointer(tp->funcs, tp_funcs);
-       } else {
+               /*
+                * Make sure static func never uses incorrect data after a
+                * N->...->2->1 (N>2) transition sequence. If the first
+                * element's data has changed, then force the synchronization
+                * to prevent current readers that have loaded the old data
+                * from calling the new function.
+                */
+               if (tp_funcs[0].data != old[0].data)
+                       tp_rcu_get_state(TP_TRANSITION_SYNC_N_2_1);
+               tp_rcu_cond_sync(TP_TRANSITION_SYNC_N_2_1);
+               /* Set static call to first function */
+               tracepoint_update_call(tp, tp_funcs);
+               break;
+       case TP_FUNC_2:         /* N->N-1 (N>2) */
+               fallthrough;
+       case TP_FUNC_N:
                rcu_assign_pointer(tp->funcs, tp_funcs);
-               tracepoint_update_call(tp, tp_funcs,
-                                      tp_funcs[0].func != old[0].func);
+               /*
+                * Make sure static func never uses incorrect data after a
+                * N->...->2->1 (N>2) transition sequence.
+                */
+               if (tp_funcs[0].data != old[0].data)
+                       tp_rcu_get_state(TP_TRANSITION_SYNC_N_2_1);
+               break;
+       default:
+               WARN_ON_ONCE(1);
+               break;
        }
        release_probes(old);
        return 0;
@@ -462,7 +577,7 @@ bool trace_module_has_bad_taint(struct module *mod)
 static BLOCKING_NOTIFIER_HEAD(tracepoint_notify_list);
 
 /**
- * register_tracepoint_notifier - register tracepoint coming/going notifier
+ * register_tracepoint_module_notifier - register tracepoint coming/going notifier
  * @nb: notifier block
  *
  * Notifiers registered with this function are called on module
@@ -488,7 +603,7 @@ end:
 EXPORT_SYMBOL_GPL(register_tracepoint_module_notifier);
 
 /**
- * unregister_tracepoint_notifier - unregister tracepoint coming/going notifier
+ * unregister_tracepoint_module_notifier - unregister tracepoint coming/going notifier
  * @nb: notifier block
  *
  * The notifier block callback should expect a "struct tp_module" data