x86/alternative: Handle Jcc __x86_indirect_thunk_\reg
authorPeter Zijlstra <peterz@infradead.org>
Tue, 26 Oct 2021 12:01:43 +0000 (14:01 +0200)
committerPeter Zijlstra <peterz@infradead.org>
Thu, 28 Oct 2021 21:25:28 +0000 (23:25 +0200)
Handle the rare cases where the compiler (clang) does an indirect
conditional tail-call using:

  Jcc __x86_indirect_thunk_\reg

For the !RETPOLINE case this can be rewritten to fit the original (6
byte) instruction like:

  Jncc.d8 1f
  JMP *%\reg
  NOP
1:

Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Reviewed-by: Borislav Petkov <bp@suse.de>
Acked-by: Josh Poimboeuf <jpoimboe@redhat.com>
Tested-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/r/20211026120310.296470217@infradead.org
arch/x86/kernel/alternative.c

index 5df4034..1dea2f6 100644 (file)
@@ -393,7 +393,8 @@ static int emit_indirect(int op, int reg, u8 *bytes)
 static int patch_retpoline(void *addr, struct insn *insn, u8 *bytes)
 {
        retpoline_thunk_t *target;
-       int reg, i = 0;
+       int reg, ret, i = 0;
+       u8 op, cc;
 
        target = addr + insn->length + insn->immediate.value;
        reg = target - __x86_indirect_thunk_array;
@@ -407,9 +408,36 @@ static int patch_retpoline(void *addr, struct insn *insn, u8 *bytes)
        if (cpu_feature_enabled(X86_FEATURE_RETPOLINE))
                return -1;
 
-       i = emit_indirect(insn->opcode.bytes[0], reg, bytes);
-       if (i < 0)
-               return i;
+       op = insn->opcode.bytes[0];
+
+       /*
+        * Convert:
+        *
+        *   Jcc.d32 __x86_indirect_thunk_\reg
+        *
+        * into:
+        *
+        *   Jncc.d8 1f
+        *   JMP *%\reg
+        *   NOP
+        * 1:
+        */
+       /* Jcc.d32 second opcode byte is in the range: 0x80-0x8f */
+       if (op == 0x0f && (insn->opcode.bytes[1] & 0xf0) == 0x80) {
+               cc = insn->opcode.bytes[1] & 0xf;
+               cc ^= 1; /* invert condition */
+
+               bytes[i++] = 0x70 + cc;        /* Jcc.d8 */
+               bytes[i++] = insn->length - 2; /* sizeof(Jcc.d8) == 2 */
+
+               /* Continue as if: JMP.d32 __x86_indirect_thunk_\reg */
+               op = JMP32_INSN_OPCODE;
+       }
+
+       ret = emit_indirect(op, reg, bytes + i);
+       if (ret < 0)
+               return ret;
+       i += ret;
 
        for (; i < insn->length;)
                bytes[i++] = BYTES_NOP1;
@@ -443,6 +471,10 @@ void __init_or_module noinline apply_retpolines(s32 *start, s32 *end)
                case JMP32_INSN_OPCODE:
                        break;
 
+               case 0x0f: /* escape */
+                       if (op2 >= 0x80 && op2 <= 0x8f)
+                               break;
+                       fallthrough;
                default:
                        WARN_ON_ONCE(1);
                        continue;