forked from scylladb/scylladb
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstream_compressor.cc
505 lines (452 loc) · 20.8 KB
/
stream_compressor.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
/*
* Copyright (C) 2023-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "utils/stream_compressor.hh"
#include <array>
#include <memory>
#include <bit>
#include <seastar/core/byteorder.hh>
#include <seastar/rpc/rpc_types.hh>
#include <seastar/core/memory.hh>
#include "utils/small_vector.hh"
#include "seastarx.hh"
#include "utils/crc.hh"
#define ZSTD_STATIC_LINKING_ONLY
#include <zstd.h>
namespace utils {
namespace {
// Returns the total byte length of a varint, given only its first
// (least significant) byte. The length is encoded in unary in the low
// bits: an N-byte varint has (N-1) trailing zero bits followed by a 1.
size_t varint_length(uint8_t first_byte) {
    return std::countr_zero(first_byte) + 1;
}
// Decodes a varint from a little-endian load of the memory that holds it.
// The trailing-zero count of the loaded word gives the varint's byte length;
// bytes beyond the varint are masked away before the payload is extracted.
uint32_t varint_decode(uint64_t all_bytes) {
    const size_t n_bytes = std::countr_zero(all_bytes) + 1;
    // Keep only the n_bytes low-order bytes, then drop the unary length marker.
    const uint64_t payload_mask = ~uint64_t(0) >> (64 - 8 * n_bytes);
    return (all_bytes & payload_mask) >> n_bytes;
}
// Encodes x as a varint: the payload is shifted left past a unary length
// marker (n_bytes-1 zero bits followed by a 1 bit in the least significant
// position). Returns {byte length, encoded value} — the encoded value's low
// n_bytes bytes are the wire representation, to be written little-endian.
//
// The shift is performed in 64 bits: a full 32-bit payload needs up to
// 8 * 5 = 40 bits once the length marker is added, so computing
// ((x << 1) | 1) in uint32_t (as the code previously did) would silently
// truncate the high bits of large values.
std::pair<size_t, uint64_t> varint_encode(uint32_t x) {
    size_t n_bytes = (std::max(std::bit_width(x), 1) + 6) / 7;
    return {n_bytes, ((uint64_t(x) << 1) | 1) << (n_bytes - 1)};
}
} // namespace
// Moves as many bytes as possible from `in` to `out`, advancing both cursors.
// Always returns 0: the raw stream never retains data internally.
size_t raw_stream::copy(ZSTD_outBuffer* out, ZSTD_inBuffer* in) {
    const size_t out_room = out->size - out->pos;
    const size_t in_avail = in->size - in->pos;
    const size_t n = in_avail < out_room ? in_avail : out_room;
    memcpy(static_cast<char*>(out->dst) + out->pos, static_cast<const char*>(in->src) + in->pos, n);
    out->pos += n;
    in->pos += n;
    return 0;
}
// "Decompression" for the identity stream is a plain byte copy. The
// end-of-frame flag is irrelevant because no state is kept between calls.
void raw_stream::decompress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, bool) {
    (void)copy(out, in);
}
// "Compression" for the identity stream is a plain byte copy. The end
// directive is ignored; copy() always returns 0, i.e. nothing is buffered.
size_t raw_stream::compress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, ZSTD_EndDirective) {
    return copy(out, in);
}
// The raw (identity) stream is stateless, so resetting it is a no-op.
void raw_stream::reset() noexcept {
}
// Throw if ret is a ZSTD error code; `text` names the failing call for the message.
static void check_zstd(size_t ret, const char* text) {
    if (!ZSTD_isError(ret)) {
        return;
    }
    throw std::runtime_error(fmt::format("{} error: {}", text, ZSTD_getErrorName(ret)));
}
// Creates a zstd decompression stream configured for magicless frames
// (matching zstd_cstream, which emits frames without the 4-byte magic number).
zstd_dstream::zstd_dstream() {
    _ctx.reset(ZSTD_createDStream());
    if (!_ctx) {
        // ZSTD_createDStream() returns nullptr on allocation failure.
        // This check must come before any use of the context: the original
        // code called ZSTD_DCtx_setParameter() on a possibly-null context.
        throw std::bad_alloc();
    }
    check_zstd(ZSTD_DCtx_setParameter(_ctx.get(), ZSTD_d_format, ZSTD_f_zstd1_magicless), "ZSTD_DCtx_setParameter(.., ZSTD_d_format, ZSTD_f_zstd1_magicless)");
}
// Abandons any partially-decoded frame so the stream can start afresh.
// Only the session is reset; parameters and the attached dictionary survive.
void zstd_dstream::reset() noexcept {
    _in_progress = false;
    ZSTD_DCtx_reset(_ctx.get(), ZSTD_reset_session_only);
}
// Decompresses bytes from `in` into `out`. `end_of_frame` tells the stream
// that `in` contains the final bytes of the current frame, which enables the
// truncation sanity check below.
void zstd_dstream::decompress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, bool end_of_frame) {
    if (!_in_progress && in->pos == in->size) {
        // Without this early return, we would start the decompression of the next frame.
        // ZSTD_decompressStream() would return something positive, and our truncation
        // checking logic could wrongly claim truncation.
        return;
    }
    size_t ret = ZSTD_decompressStream(_ctx.get(), out, in);
    check_zstd(ret, "ZSTD_decompressStream");
    // A non-zero return value means the current frame isn't fully decoded yet.
    _in_progress = bool(ret);
    if (in->pos == in->size && end_of_frame && _in_progress) {
        // Should never happen in practice, barring a bug/corruption.
        // It's not the compressor's job to verify data integrity,
        // so we could as well not check for this.
        //
        // But we do it as a good practice.
        // To an extent, it keeps the effects of a corruption localized to the corrupted message.
        // This could make some possible debugging easier.
        // (If we didn't check for truncation, a bad message could leave leftovers in the compressor,
        // which would cause weirdness when decompressing the next -- perhaps correct -- message).
        throw std::runtime_error("truncated ZSTD frame");
    }
}
// Attaches (or, with nullptr, detaches) a decompression dictionary.
// The dictionary must stay alive until it is replaced by another set_dict()
// call or the stream is destroyed.
void zstd_dstream::set_dict(const ZSTD_DDict* dict) {
    ZSTD_DCtx_refDDict(_ctx.get(), dict);
    _dict = dict;
}
// Creates a zstd compression stream producing minimal frames: magicless,
// with content size, checksum and dictID fields disabled.
zstd_cstream::zstd_cstream() {
    _ctx.reset(ZSTD_createCStream());
    if (!_ctx) {
        throw std::bad_alloc();
    }
    // For now, we hardcode a 128 kiB window and the lowest compression level here.
    check_zstd(ZSTD_initCStream(_ctx.get(), 1), "ZSTD_initCStream(.., 1)");
    auto set_param = [this](ZSTD_cParameter param, int value, const char* what) {
        check_zstd(ZSTD_CCtx_setParameter(_ctx.get(), param, value), what);
    };
    set_param(ZSTD_c_format, ZSTD_f_zstd1_magicless, "ZSTD_CCtx_setParameter(.., ZSTD_c_format, ZSTD_f_zstd1_magicless)");
    set_param(ZSTD_c_contentSizeFlag, 0, "ZSTD_CCtx_setParameter(.., ZSTD_c_contentSizeFlag, 0)");
    set_param(ZSTD_c_checksumFlag, 0, "ZSTD_CCtx_setParameter(.., ZSTD_c_checksumFlag, 0)");
    set_param(ZSTD_c_dictIDFlag, 0, "ZSTD_CCtx_setParameter(.., ZSTD_c_dictIDFlag, 0)");
    set_param(ZSTD_c_windowLog, 17, "ZSTD_CCtx_setParameter(.., ZSTD_c_windowLog, 17)");
}
// Feeds `in` through ZSTD_compressStream2 into `out`, honoring the end
// directive (continue/flush/end). Returns zstd's progress hint: non-zero
// means more output is still pending inside the compressor.
size_t zstd_cstream::compress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, ZSTD_EndDirective end) {
    const size_t remaining = ZSTD_compressStream2(_ctx.get(), out, in, end);
    check_zstd(remaining, "ZSTD_compressStream2");
    return remaining;
}
// Abandons the current frame (if any). Only the session is reset;
// parameters and the attached dictionary survive.
void zstd_cstream::reset() noexcept {
    ZSTD_CCtx_reset(_ctx.get(), ZSTD_reset_session_only);
}
// Attaches (or, with nullptr, detaches) a compression dictionary.
// The dictionary must stay alive until it is replaced by another set_dict()
// call or the stream is destroyed.
void zstd_cstream::set_dict(const ZSTD_CDict* dict) {
    ZSTD_CCtx_refCDict(_ctx.get(), dict);
    _dict = dict;
}
// Used as an intermediary buffer by lz4_cstream and lz4_dstream.
// Sized to hold one worst-case compressed block (LZ4_COMPRESSBOUND of the
// maximum window) plus sizeof(uint64_t) bytes of room for the varint length
// header that precedes each block.
//
// Since this buffer is shared, lz4_cstream and lz4_dstream must be used synchronously.
// That is, when the processing of a frame (the string of data between `reset()`s) starts,
// it has to be finished before another compressor instance attempts to use it.
//
// If there ever is a need to process multiple frames concurrently, the dependency on this
// buffer has to be eliminated, and each concurrent stream must be given a separate buffer.
//
// This makes for a bad/dangerous API, but it's only used in this file, so it's okay for now.
static thread_local std::array<char, LZ4_COMPRESSBOUND(max_lz4_window_size) + sizeof(uint64_t)> _lz4_scratch;
// `window_size` is the capacity of the ring buffer backing the LZ4 streaming
// window; data previously placed in it acts as the match dictionary.
lz4_cstream::lz4_cstream(size_t window_size)
    : _buf(window_size)
{
    // Fully initialize the LZ4 context before first use.
    reset();
}
// Fully reinitializes the LZ4 stream state, re-attaches the configured
// dictionary (if any), and rewinds the ring-buffer and scratch cursors.
void lz4_cstream::reset() noexcept {
    _buf_pos = 0;
    _scratch_beg = 0;
    _scratch_end = 0;
    LZ4_initStream(&_ctx, sizeof(_ctx));
    LZ4_attach_dictionary(&_ctx, _dict);
}
// Like reset(), but uses the cheaper LZ4_resetStream_fast(), which is only
// valid when the context has already been correctly initialized before.
void lz4_cstream::resetFast() noexcept {
    _buf_pos = 0;
    _scratch_beg = 0;
    _scratch_end = 0;
    LZ4_resetStream_fast(&_ctx);
    LZ4_attach_dictionary(&_ctx, _dict);
}
// When new data arrives in `in`, we copy an arbitrary amount of it to `_buf`,
// (the amount is arbitrary, but it has to fit contiguously in `_buf`),
// compress the new block from `_buf` to `_lz4_scratch`,
// then we copy everything from `_lz4_scratch` to `out`.
// Repeat until `in` is empty.
//
// Returns non-zero iff compressed data is still waiting in `_lz4_scratch`
// (the caller must then call again, with more room in `out`); 0 otherwise.
size_t lz4_cstream::compress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, ZSTD_EndDirective end) {
    if (_scratch_end == _scratch_beg && in->size != in->pos) {
        // If we already copied everything we compressed to `out`,
        // and there is still some data in `in`,
        // we are going to compress a new block from `in` to `_lz4_scratch`.
        if (_buf_pos == _buf.size()) {
            // The ring buffer wraps around here.
            _buf_pos = 0;
        }
        // We will compress the biggest prefix of `in` that fits contiguously inside `buf`.
        // In principle, this is slightly suboptimal -- ideally, if `in` is smaller than the contiguous space in `buf`,
        // we should only copy `in` to `buf` and wait with the compressor call until future `in`s
        // fill the contiguous space entirely, or `end` is `ZSTD_e_flush` or `ZSTD_e_end`.
        // But for streaming LZ4 it doesn't really make a difference.
        size_t n = std::min(_buf.size() - _buf_pos, in->size - in->pos);
        // The compressed block mustn't fill the entire ring buffer.
        // It must be at least 1 byte shorter. It's a dumb quirk of lz4. Perhaps it should be called a bug.
        n = std::min(max_lz4_window_size - 1, n);
        std::memcpy(_buf.data() + _buf_pos, static_cast<const char*>(in->src) + in->pos, n);
        in->pos += n;
        // The first header_size bytes contain the length of the compressed block, so we compress to _lz4_scratch.data() + sizeof(uint64_t).
        // NOTE(review): the capacity passed is the full scratch size even though
        // writing starts at offset sizeof(uint64_t); this is safe because the
        // output can't exceed LZ4_COMPRESSBOUND(max_lz4_window_size), which is
        // exactly _lz4_scratch.size() - sizeof(uint64_t).
        int x = LZ4_compress_fast_continue(&_ctx, _buf.data() + _buf_pos, _lz4_scratch.data() + sizeof(uint64_t), n, _lz4_scratch.size(), 1);
        if (x < 0) {
            throw std::runtime_error(fmt::format(
                "LZ4_compress_fast_continue failed with negative return value {}. "
                "LZ4 shouldn't fail. This indicates a bug or a corruption.",
                x));
        }
        auto [header_size, header_value] = varint_encode(x);
        // The block's bytes-to-copy live at scratch offsets
        // [sizeof(uint64_t) - header_size, sizeof(uint64_t) + x).
        _scratch_end = sizeof(uint64_t) + x;
        _scratch_beg = sizeof(uint64_t) - header_size;
        _buf_pos += n;
        // We prepend to every block its compressed length.
        // This is necessary, because LZ4 needs to know it, but it doesn't store it by itself.
        // Shifting the varint into the high bytes makes its little-endian
        // representation land exactly in the header_size bytes before the block.
        seastar::write_le<uint64_t>(_lz4_scratch.data(), header_value << (8 * (sizeof(uint64_t) - header_size)));
    }
    if (_scratch_end > _scratch_beg) {
        // If we compressed a block to _lz4_scratch, but we still haven't copied it all to `out`,
        // we copy as much as possible to `out`.
        size_t n = std::min(_scratch_end - _scratch_beg, out->size - out->pos);
        std::memcpy(static_cast<char*>(out->dst) + out->pos, _lz4_scratch.data() + _scratch_beg, n);
        _scratch_beg += n;
        out->pos += n;
    }
    if (_scratch_beg == _scratch_end && in->size == in->pos && end == ZSTD_e_end) {
        // If we have no more data in `_lz4_scratch`, there is no more data in `in`, and `end`
        // says that this is the last block in the stream, we reset the stream state.
        resetFast();
        // Note that the we currently don't emit anything to mark end of stream.
        // We rely on the outer layer to notify the decompressor about it manually.
        // This is inconsistent with what zstd's API does. We could fix that for elegance points.
    };
    // Return a non-zero value if there is still some data lingering in the compressor's
    // internal buffers (`_lz4_scratch`), which has to be copied out.
    // This will prompt the caller to call us again, even if `in` is empty.
    return _scratch_beg != _scratch_end;
}
// The passed dict must live until it is unset by another set_dict(). (Or until the compressor is destroyed).
// The pointer can be null, this will unset the current dict.
void lz4_cstream::set_dict(const LZ4_stream_t* dict) {
    _dict = dict;
    // resetFast() re-attaches the new dictionary and discards window state,
    // which would otherwise be inconsistent with the new dict.
    resetFast();
}
// `window_size` is the capacity of the ring buffer that decompressed blocks
// are written into; it must match the compressor side's window.
lz4_dstream::lz4_dstream(size_t window_size)
    : _buf(window_size)
{}
// Reinitializes the decoder for a new frame, seeding it with the configured
// dictionary (if any), and rewinds all buffer cursors.
void lz4_dstream::reset() noexcept {
    _buf_beg = 0;
    _buf_end = 0;
    _scratch_pos = 0;
    const char* dict_data = _dict.empty() ? nullptr : reinterpret_cast<const char*>(_dict.data());
    LZ4_setStreamDecode(&_ctx, dict_data, _dict.size());
}
// Decompresses one stream of length-prefixed LZ4 blocks.
// Each block arrives as: varint(compressed length) || compressed bytes.
// The block is first accumulated contiguously in `_lz4_scratch` (input may
// arrive in arbitrarily small pieces), then decompressed into the `_buf`
// ring buffer, from which it is copied out to `out`.
void lz4_dstream::decompress(ZSTD_outBuffer* out, ZSTD_inBuffer* in, bool end_of_frame) {
    if (_buf_beg == _buf_end) {
        // If we have no decompressed data that wasn't yet output,
        // we will decompress a new block.
        if (_buf_end == _buf.size()) {
            // The ring buffer wraps around here.
            _buf_end = _buf_beg = 0;
        }
        // First, we have to ingest the first few bytes of input data, which contain
        // the post-compression size of the following block.
        // Start by reading 1 byte -- enough to learn the header's true length.
        size_t header_size = 1;
        if (_scratch_pos < header_size) {
            size_t n = std::min(header_size - _scratch_pos, in->size - in->pos);
            std::memcpy(_lz4_scratch.data() + _scratch_pos, static_cast<const char*>(in->src) + in->pos, n);
            _scratch_pos += n;
            in->pos += n;
        }
        if (_scratch_pos >= header_size) {
            // The first byte's trailing zeros give the full header length.
            header_size = varint_length(static_cast<uint8_t>(_lz4_scratch[0]));
        }
        if (_scratch_pos < header_size) {
            // Ingest the remaining header bytes, if any.
            size_t n = std::min(header_size - _scratch_pos, in->size - in->pos);
            std::memcpy(_lz4_scratch.data() + _scratch_pos, static_cast<const char*>(in->src) + in->pos, n);
            _scratch_pos += n;
            in->pos += n;
        }
        // If we know the first header_size bytes, we can read the compressed size and ingest that many bytes,
        // so we have a full compressed block in contiguous memory.
        if (_scratch_pos >= header_size) {
            // NOTE(review): a 32-bit read suffices here only while headers fit
            // in 4 bytes (i.e. block lengths below 2^28), which holds for the
            // compressor in this file -- confirm if block sizes ever grow.
            auto x = varint_decode(seastar::read_le<uint32_t>(_lz4_scratch.data()));
            if (x + header_size > _lz4_scratch.size()) {
                throw std::runtime_error(fmt::format("Oversized LZ4 block: {} bytes.", x + header_size));
            }
            size_t n = std::min(x + header_size - _scratch_pos, in->size - in->pos);
            std::memcpy(_lz4_scratch.data() + _scratch_pos, static_cast<const char*>(in->src) + in->pos, n);
            _scratch_pos += n;
            in->pos += n;
            if (_scratch_pos == x + header_size) {
                // Now the full compressed block is in contiguous memory.
                // We can decompress it to `_buf` and clear `_lz4_scratch`.
                size_t n_buf = _buf.size() - _buf_end;
                int ret = LZ4_decompress_safe_continue(&_ctx, _lz4_scratch.data() + header_size, _buf.data() + _buf_end, x, n_buf);
                if (ret < 0) {
                    throw std::runtime_error(fmt::format("LZ4_decompress_safe_continue failed with negative return value {}. This indicates a corruption of the RPC connection.", ret));
                }
                _buf_end += ret;
                _scratch_pos = 0;
            }
        }
    }
    if (_buf_end > _buf_beg) {
        // If we have some decompressed data that wasn't yet output,
        // we output as much as possible.
        size_t n = std::min(_buf_end - _buf_beg, out->size - out->pos);
        std::memcpy(static_cast<char*>(out->dst) + out->pos, _buf.data() + _buf_beg, n);
        out->pos += n;
        _buf_beg += n;
    }
    if (end_of_frame && _buf_beg == _buf_end && in->size == in->pos) {
        // If there is no data in internal buffers to be output, there is no more data in `in`,
        // and `in` was the last block in the stream, we reset the stream state.
        if (_scratch_pos != 0) {
            // A partial header/block left in scratch at end of frame means the input was cut short.
            throw std::runtime_error(fmt::format("Truncated LZ4 frame."));
        }
        reset();
    }
    // Sanity check.
    assert(_buf_end >= _buf_beg);
}
// The passed dict must live until it is unset by another set_dict(). (Or until the decompressor is destroyed).
// The span can be empty, this will unset the current dict.
void lz4_dstream::set_dict(std::span<const std::byte> dict) {
    _dict = dict;
    // reset() re-seeds the decoder with the new dictionary and discards
    // any window state built against the old one.
    reset();
}
// Compresses `data` with `compressor` into a chain of chunk_size-sized
// fragments, leaving `head_space` uninitialized bytes at the front of the
// result (for the caller's framing header). `end_of_frame` selects between
// finishing the compression frame (ZSTD_e_end) and merely flushing it.
// On any error the compressor is reset so the next message starts clean.
rpc::snd_buf compress_impl(size_t head_space, const rpc::snd_buf& data, stream_compressor& compressor, bool end_of_frame, size_t chunk_size) try {
    const auto size = data.size;
    // `bufs` is either a single buffer or a vector of them; normalize to a
    // pointer to the first fragment in both cases.
    auto src = std::get_if<temporary_buffer<char>>(&data.bufs);
    if (!src) {
        src = std::get<std::vector<temporary_buffer<char>>>(data.bufs).data();
    }
    size_t size_left = size;
    size_t size_compressed = 0;
    ZSTD_inBuffer inbuf = {};
    ZSTD_outBuffer outbuf = {};
    small_vector<temporary_buffer<char>, 16> dst_buffers;
    // Note: we always allocate chunk_size here, then we resize it to fit at the end.
    // Maybe that's a waste of cycles, and we should allocate a buffer that's about as big
    // as the input, and not shrink at the end?
    dst_buffers.emplace_back(std::max<size_t>(head_space, chunk_size));
    outbuf.dst = dst_buffers.back().get_write();
    outbuf.size = dst_buffers.back().size();
    outbuf.pos = head_space;
    // Stream compressors can handle frames of size 0,
    // but when a Zstd compressor is called multiple times (around 10, I think?)
    // with an empty input, it tries to be helpful and emits a "no progress" error.
    //
    // To avoid that, we special-case empty frames with the `if`
    // below, and we just return an empty buffer as the result,
    // without putting the empty message through the compressor.
    if (size > 0) {
        while (true) {
            if (size_left && inbuf.pos == inbuf.size) {
                // Current input fragment exhausted; move on to the next one.
                size_left -= src->size();
                inbuf.src = src->get();
                inbuf.size = src->size();
                inbuf.pos = 0;
                ++src;
                continue;
            }
            if (outbuf.pos == outbuf.size) {
                // Current output fragment full; start a new one.
                size_compressed += outbuf.pos;
                dst_buffers.emplace_back(chunk_size);
                outbuf.dst = dst_buffers.back().get_write();
                outbuf.size = dst_buffers.back().size();
                outbuf.pos = 0;
                continue;
            }
            size_t ret = compressor.compress(&outbuf, &inbuf, size_left ? ZSTD_e_continue : (end_of_frame ? ZSTD_e_end : ZSTD_e_flush));
            if (!size_left // No more input chunks.
                && inbuf.pos == inbuf.size // No more data in the last chunk.
                && ret == 0 // No data remaining in compressor's internal buffers.
            ) {
                break;
            }
        }
    }
    size_compressed += outbuf.pos;
    dst_buffers.back().trim(outbuf.pos);
    // In this routine, we always allocate fragments of size chunk_size, even
    // if the message is tiny.
    // Hence, at this point, dst_buffers contains up to (128 kiB) unused memory.
    // When there are many concurrent RPC messages, this waste might add up to
    // a considerable overhead.
    // Let's pay some cycles to shrink the underlying allocation to fit, to
    // avoid any problems with that unseen memory usage.
    dst_buffers.back() = dst_buffers.back().clone();
    if (dst_buffers.size() == 1) {
        return rpc::snd_buf(std::move(dst_buffers.front()));
    }
    return rpc::snd_buf({std::make_move_iterator(dst_buffers.begin()), std::make_move_iterator(dst_buffers.end())}, size_compressed);
} catch (...) {
    // Leftover state from a failed message would corrupt the next one.
    compressor.reset();
    throw;
}
// Computes the CRC32 of all payload bytes in `data`, which stores its
// fragments either as a single temporary_buffer or as a vector of them.
template <typename T>
requires std::same_as<T, rpc::rcv_buf> || std::same_as<T, rpc::snd_buf>
uint32_t crc_impl_generic(const T& data) noexcept {
    // Normalize both storage variants to a pointer to the first fragment.
    auto frag = std::get_if<temporary_buffer<char>>(&data.bufs);
    if (!frag) {
        frag = std::get<std::vector<temporary_buffer<char>>>(data.bufs).data();
    }
    utils::crc32 crc;
    // `data.size` is the total payload size across all fragments.
    for (size_t remaining = data.size; remaining > 0; remaining -= frag->size(), ++frag) {
        crc.process(reinterpret_cast<const uint8_t*>(frag->get()), frag->size());
    }
    return crc.get();
}
// CRC32 over the payload of an outgoing (send-side) RPC buffer.
uint32_t crc_impl(const rpc::snd_buf& data) noexcept {
    return crc_impl_generic(data);
}
// CRC32 over the payload of an incoming (receive-side) RPC buffer.
uint32_t crc_impl(const rpc::rcv_buf& data) noexcept {
    return crc_impl_generic(data);
}
// Decompresses `data` with `decompressor` into a chain of chunk_size-sized
// fragments. `end_of_frame` tells the decompressor whether `data` holds the
// final bytes of the current compression frame.
// On any error the decompressor is reset so the next message starts clean.
rpc::rcv_buf decompress_impl(const seastar::rpc::rcv_buf& data, stream_decompressor& decompressor, bool end_of_frame, size_t chunk_size) try {
    // As a special-case, if we take empty input, we immediately return an empty output,
    // without going through the actual compressors.
    //
    // This is due to a quirk of Zstd, see the matching comment in `compress_impl`
    // for details.
    if (data.size == 0) {
        return rpc::rcv_buf();
    }
    size_t size_decompressed = 0;
    size_t input_remaining = data.size;
    // `bufs` is either a single buffer or a vector of them; normalize to a
    // pointer to the first fragment in both cases.
    auto it = std::get_if<temporary_buffer<char>>(&data.bufs);
    if (!it) {
        it = std::get<std::vector<temporary_buffer<char>>>(data.bufs).data();
    }
    small_vector<temporary_buffer<char>, 16> dst_buffers;
    ZSTD_inBuffer inbuf = {};
    ZSTD_outBuffer outbuf = {};
    while (true) {
        if (input_remaining && inbuf.pos == inbuf.size) {
            // Current input fragment exhausted; move on to the next one.
            inbuf.src = it->get();
            inbuf.size = it->size();
            input_remaining -= it->size();
            ++it;
            inbuf.pos = 0;
            continue;
        }
        if (outbuf.pos == outbuf.size) {
            // Current output fragment full (or none allocated yet); start a new one.
            size_decompressed += outbuf.pos;
            dst_buffers.emplace_back(chunk_size);
            outbuf.dst = dst_buffers.back().get_write();
            outbuf.size = dst_buffers.back().size();
            outbuf.pos = 0;
            continue;
        }
        decompressor.decompress(&outbuf, &inbuf, end_of_frame && !input_remaining);
        if (!input_remaining // No more input chunks.
            && inbuf.pos == inbuf.size // No more data in the last input chunk.
            && outbuf.pos < outbuf.size // No more data in decompressor's internal buffers.
        ) {
            break;
        }
    }
    size_decompressed += outbuf.pos;
    if (!dst_buffers.empty()) {
        dst_buffers.back().trim(outbuf.pos);
        // In this routine, we always allocate fragments of size chunk_size, even
        // if the message is tiny.
        // Hence, at this point, dst_buffers contains up to (128 kiB) unused memory.
        // When there are many concurrent RPC messages, this waste might add up to
        // a considerable overhead.
        // Let's pay some cycles to shrink the underlying allocation to fit, to
        // avoid any problems with that unseen memory usage.
        // We could avoid the CPU cost if we prepended the decompressed size to
        // the message (thus growing the messages by 1-2 bytes) and used that to
        // allocate exactly the right buffer.
        // But I don't know if it's worth the 2 bytes.
        dst_buffers.back() = dst_buffers.back().clone();
    }
    if (dst_buffers.size() == 1) {
        return rpc::rcv_buf(std::move(dst_buffers.front()));
    }
    return rpc::rcv_buf({std::make_move_iterator(dst_buffers.begin()), std::make_move_iterator(dst_buffers.end())}, size_decompressed);
} catch (...) {
    // Leftover state from a failed message would corrupt the next one.
    decompressor.reset();
    throw;
}
} // namespace utils