Merge branches 'for-next/sysreg', 'for-next/sme', 'for-next/kselftest', 'for-next...
[linux-2.6-microblaze.git] / tools / testing / selftests / arm64 / abi / syscall-abi.c
index dd7ebe5..18cc123 100644 (file)
 
 #include "syscall-abi.h"
 
-#define NUM_VL ((SVE_VQ_MAX - SVE_VQ_MIN) + 1)
-
 static int default_sme_vl;
 
+static int sve_vl_count;
+static unsigned int sve_vls[SVE_VQ_MAX];
+static int sme_vl_count;
+static unsigned int sme_vls[SVE_VQ_MAX];
+
 extern void do_syscall(int sve_vl, int sme_vl);
 
 static void fill_random(void *buf, size_t size)
@@ -83,6 +86,7 @@ static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t s
 #define NUM_FPR 32
 uint64_t fpr_in[NUM_FPR * 2];
 uint64_t fpr_out[NUM_FPR * 2];
+uint64_t fpr_zero[NUM_FPR * 2];
 
 static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
                      uint64_t svcr)
@@ -97,7 +101,7 @@ static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
        int errors = 0;
        int i;
 
-       if (!sve_vl) {
+       if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
                for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
                        if (fpr_in[i] != fpr_out[i]) {
                                ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
@@ -109,6 +113,18 @@ static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
                }
        }
 
+       /*
+        * In streaming mode the whole register set should be cleared
+        * by the transition out of streaming mode.
+        */
+       if (svcr & SVCR_SM_MASK) {
+               if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
+                       ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
+                                      cfg->name);
+                       errors++;
+               }
+       }
+
        return errors;
 }
 
@@ -284,8 +300,8 @@ static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
        return errors;
 }
 
-uint8_t za_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
-uint8_t za_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
+uint8_t za_in[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];
+uint8_t za_out[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];
 
 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
                     uint64_t svcr)
@@ -311,6 +327,35 @@ static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
        return errors;
 }
 
+uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
+uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
+
+static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
+                    uint64_t svcr)
+{
+       fill_random(zt_in, sizeof(zt_in));
+       memset(zt_out, 0, sizeof(zt_out));
+}
+
+static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
+                   uint64_t svcr)
+{
+       int errors = 0;
+
+       if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
+               return 0;
+
+       if (!(svcr & SVCR_ZA_MASK))
+               return 0;
+
+       if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
+               ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
+               errors++;
+       }
+
+       return errors;
+}
+
 typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
                         uint64_t svcr);
 typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
@@ -334,6 +379,7 @@ static struct {
        { setup_ffr, check_ffr },
        { setup_svcr, check_svcr },
        { setup_za, check_za },
+       { setup_zt, check_zt },
 };
 
 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
@@ -355,73 +401,78 @@ static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
 
 static void test_one_syscall(struct syscall_cfg *cfg)
 {
-       int sve_vq, sve_vl;
-       int sme_vq, sme_vl;
+       int sve, sme;
+       int ret;
 
        /* FPSIMD only case */
        ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
                         "%s FPSIMD\n", cfg->name);
 
-       if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
-               return;
-
-       for (sve_vq = SVE_VQ_MAX; sve_vq > 0; --sve_vq) {
-               sve_vl = prctl(PR_SVE_SET_VL, sve_vq * 16);
-               if (sve_vl == -1)
+       for (sve = 0; sve < sve_vl_count; sve++) {
+               ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
+               if (ret == -1)
                        ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
                                           strerror(errno), errno);
 
-               sve_vl &= PR_SVE_VL_LEN_MASK;
-
-               if (sve_vq != sve_vq_from_vl(sve_vl))
-                       sve_vq = sve_vq_from_vl(sve_vl);
-
-               ksft_test_result(do_test(cfg, sve_vl, default_sme_vl, 0),
-                                "%s SVE VL %d\n", cfg->name, sve_vl);
-
-               if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
-                       continue;
+               ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
+                                "%s SVE VL %d\n", cfg->name, sve_vls[sve]);
 
-               for (sme_vq = SVE_VQ_MAX; sme_vq > 0; --sme_vq) {
-                       sme_vl = prctl(PR_SME_SET_VL, sme_vq * 16);
-                       if (sme_vl == -1)
+               for (sme = 0; sme < sme_vl_count; sme++) {
+                       ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
+                       if (ret == -1)
                                ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
                                                   strerror(errno), errno);
 
-                       sme_vl &= PR_SME_VL_LEN_MASK;
-
-                       if (sme_vq != sve_vq_from_vl(sme_vl))
-                               sme_vq = sve_vq_from_vl(sme_vl);
-
-                       ksft_test_result(do_test(cfg, sve_vl, sme_vl,
+                       ksft_test_result(do_test(cfg, sve_vls[sve],
+                                                sme_vls[sme],
                                                 SVCR_ZA_MASK | SVCR_SM_MASK),
                                         "%s SVE VL %d/SME VL %d SM+ZA\n",
-                                        cfg->name, sve_vl, sme_vl);
-                       ksft_test_result(do_test(cfg, sve_vl, sme_vl,
-                                                SVCR_SM_MASK),
+                                        cfg->name, sve_vls[sve],
+                                        sme_vls[sme]);
+                       ksft_test_result(do_test(cfg, sve_vls[sve],
+                                                sme_vls[sme], SVCR_SM_MASK),
                                         "%s SVE VL %d/SME VL %d SM\n",
-                                        cfg->name, sve_vl, sme_vl);
-                       ksft_test_result(do_test(cfg, sve_vl, sme_vl,
-                                                SVCR_ZA_MASK),
+                                        cfg->name, sve_vls[sve],
+                                        sme_vls[sme]);
+                       ksft_test_result(do_test(cfg, sve_vls[sve],
+                                                sme_vls[sme], SVCR_ZA_MASK),
                                         "%s SVE VL %d/SME VL %d ZA\n",
-                                        cfg->name, sve_vl, sme_vl);
+                                        cfg->name, sve_vls[sve],
+                                        sme_vls[sme]);
                }
        }
+
+       for (sme = 0; sme < sme_vl_count; sme++) {
+               ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
+               if (ret == -1)
+                       ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
+                                                  strerror(errno), errno);
+
+               ksft_test_result(do_test(cfg, 0, sme_vls[sme],
+                                        SVCR_ZA_MASK | SVCR_SM_MASK),
+                                "%s SME VL %d SM+ZA\n",
+                                cfg->name, sme_vls[sme]);
+               ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
+                                "%s SME VL %d SM\n",
+                                cfg->name, sme_vls[sme]);
+               ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
+                                "%s SME VL %d ZA\n",
+                                cfg->name, sme_vls[sme]);
+       }
 }
 
-int sve_count_vls(void)
+void sve_count_vls(void)
 {
        unsigned int vq;
-       int vl_count = 0;
        int vl;
 
        if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
-               return 0;
+               return;
 
        /*
         * Enumerate up to SVE_VQ_MAX vector lengths
         */
-       for (vq = SVE_VQ_MAX; vq > 0; --vq) {
+       for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
                vl = prctl(PR_SVE_SET_VL, vq * 16);
                if (vl == -1)
                        ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
@@ -432,28 +483,22 @@ int sve_count_vls(void)
                if (vq != sve_vq_from_vl(vl))
                        vq = sve_vq_from_vl(vl);
 
-               vl_count++;
+               sve_vls[sve_vl_count++] = vl;
        }
-
-       return vl_count;
 }
 
-int sme_count_vls(void)
+void sme_count_vls(void)
 {
        unsigned int vq;
-       int vl_count = 0;
        int vl;
 
        if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
-               return 0;
-
-       /* Ensure we configure a SME VL, used to flag if SVCR is set */
-       default_sme_vl = 16;
+               return;
 
        /*
         * Enumerate up to SVE_VQ_MAX vector lengths
         */
-       for (vq = SVE_VQ_MAX; vq > 0; --vq) {
+       for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
                vl = prctl(PR_SME_SET_VL, vq * 16);
                if (vl == -1)
                        ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
@@ -461,31 +506,47 @@ int sme_count_vls(void)
 
                vl &= PR_SME_VL_LEN_MASK;
 
+               /* Found lowest VL */
+               if (sve_vq_from_vl(vl) > vq)
+                       break;
+
                if (vq != sve_vq_from_vl(vl))
                        vq = sve_vq_from_vl(vl);
 
-               vl_count++;
+               sme_vls[sme_vl_count++] = vl;
        }
 
-       return vl_count;
+       /* Ensure we configure a SME VL, used to flag if SVCR is set */
+       default_sme_vl = sme_vls[0];
 }
 
 int main(void)
 {
        int i;
        int tests = 1;  /* FPSIMD */
+       int sme_ver;
 
        srandom(getpid());
 
        ksft_print_header();
-       tests += sve_count_vls();
-       tests += (sve_count_vls() * sme_count_vls()) * 3;
+
+       sve_count_vls();
+       sme_count_vls();
+
+       tests += sve_vl_count;
+       tests += sme_vl_count * 3;
+       tests += (sve_vl_count * sme_vl_count) * 3;
        ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
 
+       if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
+               sme_ver = 2;
+       else
+               sme_ver = 1;
+
        if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
-               ksft_print_msg("SME with FA64\n");
+               ksft_print_msg("SME%d with FA64\n", sme_ver);
        else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
-               ksft_print_msg("SME without FA64\n");
+               ksft_print_msg("SME%d without FA64\n", sme_ver);
 
        for (i = 0; i < ARRAY_SIZE(syscalls); i++)
                test_one_syscall(&syscalls[i]);