Merge tag 'drm-msm-fixes-2021-05-09' of https://gitlab.freedesktop.org/drm/msm into...
[linux-2.6-microblaze.git] / tools / perf / util / zstd.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #include <string.h>
4
5 #include "util/compress.h"
6 #include "util/debug.h"
7
8 int zstd_init(struct zstd_data *data, int level)
9 {
10         size_t ret;
11
12         data->dstream = ZSTD_createDStream();
13         if (data->dstream == NULL) {
14                 pr_err("Couldn't create decompression stream.\n");
15                 return -1;
16         }
17
18         ret = ZSTD_initDStream(data->dstream);
19         if (ZSTD_isError(ret)) {
20                 pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret));
21                 return -1;
22         }
23
24         if (!level)
25                 return 0;
26
27         data->cstream = ZSTD_createCStream();
28         if (data->cstream == NULL) {
29                 pr_err("Couldn't create compression stream.\n");
30                 return -1;
31         }
32
33         ret = ZSTD_initCStream(data->cstream, level);
34         if (ZSTD_isError(ret)) {
35                 pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret));
36                 return -1;
37         }
38
39         return 0;
40 }
41
42 int zstd_fini(struct zstd_data *data)
43 {
44         if (data->dstream) {
45                 ZSTD_freeDStream(data->dstream);
46                 data->dstream = NULL;
47         }
48
49         if (data->cstream) {
50                 ZSTD_freeCStream(data->cstream);
51                 data->cstream = NULL;
52         }
53
54         return 0;
55 }
56
57 size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
58                                        void *src, size_t src_size, size_t max_record_size,
59                                        size_t process_header(void *record, size_t increment))
60 {
61         size_t ret, size, compressed = 0;
62         ZSTD_inBuffer input = { src, src_size, 0 };
63         ZSTD_outBuffer output;
64         void *record;
65
66         while (input.pos < input.size) {
67                 record = dst;
68                 size = process_header(record, 0);
69                 compressed += size;
70                 dst += size;
71                 dst_size -= size;
72                 output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
73                                                 max_record_size : dst_size, 0 };
74                 ret = ZSTD_compressStream(data->cstream, &output, &input);
75                 ZSTD_flushStream(data->cstream, &output);
76                 if (ZSTD_isError(ret)) {
77                         pr_err("failed to compress %ld bytes: %s\n",
78                                 (long)src_size, ZSTD_getErrorName(ret));
79                         memcpy(dst, src, src_size);
80                         return src_size;
81                 }
82                 size = output.pos;
83                 size = process_header(record, size);
84                 compressed += size;
85                 dst += size;
86                 dst_size -= size;
87         }
88
89         return compressed;
90 }
91
92 size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
93                               void *dst, size_t dst_size)
94 {
95         size_t ret;
96         ZSTD_inBuffer input = { src, src_size, 0 };
97         ZSTD_outBuffer output = { dst, dst_size, 0 };
98
99         while (input.pos < input.size) {
100                 ret = ZSTD_decompressStream(data->dstream, &output, &input);
101                 if (ZSTD_isError(ret)) {
102                         pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n",
103                                src_size, output.size, dst_size, ZSTD_getErrorName(ret));
104                         break;
105                 }
106                 output.dst  = dst + output.pos;
107                 output.size = dst_size - output.pos;
108         }
109
110         return output.pos;
111 }