Skip to content

Commit 3edb410

Browse files
authored
feat: Handle account shrinking/expansion in merge_diff_copy (#116)
* feat: Handle account shrinking/expansion in merge_diff_copy * Return unwritten bytes instead of the length of written bytes
1 parent 1bcaf4d commit 3edb410

File tree

2 files changed

+223
-32
lines changed

2 files changed

+223
-32
lines changed

src/diff/algorithm.rs

Lines changed: 221 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -232,33 +232,113 @@ pub fn apply_diff_copy(original: &[u8], diffset: &DiffSet<'_>) -> Result<Vec<u8>
232232
})
233233
}
234234

235-
/// This function constructs destination by merging original with diff such that destination
236-
/// becomes the changed version of the original.
235+
/// Constructs destination by applying the diff to original, such that destination becomes the
236+
/// post-diff state of the original.
237237
///
238238
/// Precondition:
239-
/// - destination.len() == original.len()
240-
pub fn merge_diff_copy(
241-
destination: &mut [u8],
239+
/// - destination.len() == diffset.changed_len()
240+
/// - original.len() may differ from destination.len() to allow Solana
241+
/// account resizing (shrink or expand).
242+
/// Assumption:
243+
/// - destination is assumed to be zero-initialized. That automatically holds true for freshly
244+
/// allocated Solana account data. The function does NOT validate this assumption for performance reasons.
245+
/// Returns:
246+
/// - Ok(&mut [u8]) where the slice contains the trailing unwritten bytes in destination and are
247+
/// assumed to be already zero-initialized. Callers may write to those bytes or validate them.
248+
/// Notes:
249+
/// - Merge consists of:
250+
/// - bytes covered by diff segments are written from diffset.
251+
/// - unmodified regions are copied directly from original.
252+
/// - In shrink case, extra trailing bytes from original are ignored.
253+
/// - In expansion case, any remaining bytes beyond both the diff coverage
254+
/// and original.len() stay unwritten and are assumed to be zero-initialized.
255+
///
256+
pub fn merge_diff_copy<'a>(
257+
destination: &'a mut [u8],
242258
original: &[u8],
243259
diffset: &DiffSet<'_>,
244-
) -> Result<(), ProgramError> {
245-
if destination.len() != original.len() {
260+
) -> Result<&'a mut [u8], ProgramError> {
261+
if destination.len() != diffset.changed_len() {
246262
return Err(DlpError::MergeDiffError.into());
247263
}
264+
248265
let mut write_index = 0;
249266
for item in diffset.iter() {
250267
let (diff_segment, OffsetInData { start, end }) = item?;
268+
251269
if write_index < start {
270+
if start > original.len() {
271+
return Err(DlpError::InvalidDiff.into());
272+
}
252273
// copy the unchanged bytes
253274
destination[write_index..start].copy_from_slice(&original[write_index..start]);
254275
}
276+
255277
destination[start..end].copy_from_slice(diff_segment);
256278
write_index = end;
257279
}
258-
if write_index < original.len() {
259-
destination[write_index..].copy_from_slice(&original[write_index..]);
260-
}
261-
Ok(())
280+
281+
// Ensure we have overwritten all bytes in destination, otherwise "construction" of destination
282+
// will be considered incomplete.
283+
let num_bytes_written = match write_index.cmp(&destination.len()) {
284+
Ordering::Equal => {
285+
// It means the destination is fully constructed.
286+
// Nothing to do here.
287+
288+
// It is possible that destination.len() <= original.len(), i.e. the destination might have shrunk,
289+
// in which case we do not care about those bytes of original which are not part of
290+
// destination anymore.
291+
write_index
292+
}
293+
Ordering::Less => {
294+
// destination is NOT fully constructed yet. A few bytes in the destination are still unwritten.
295+
// Let's say the number of these unwritten bytes is: N.
296+
//
297+
// Now how do we construct these N unwritten bytes? We have already processed the
298+
// diffset, so now where could the values for these N bytes come from?
299+
//
300+
// There are 3 scenarios:
301+
// - All N bytes must be copied from remaining region of the original:
302+
// - that means, destination.len() <= original.len()
303+
// - and the destination might have shrunk, in which case we do not care about
304+
// the extra bytes in the original: they're discarded.
305+
// - Only (N-M) bytes come from original and the rest M bytes stay unwritten and are
306+
// "assumed" to be already zero-initialized.
307+
// - that means, destination.len() > original.len()
308+
// - write_index + (N-M) == original.len()
309+
// - and the destination has expanded.
310+
// - None of these N bytes come from original. It's basically a special case of
311+
// the second scenario: when M = N i.e all N bytes stay unwritten.
312+
// - that means, destination.len() > original.len()
313+
// - and also, write_index == original.len().
314+
// - the destination has expanded just like the above case.
315+
// - all N bytes are "assumed" to be already zero-initialized (by the caller)
316+
317+
if destination.len() <= original.len() {
318+
// case: all n bytes come from original
319+
let dest_len = destination.len();
320+
destination[write_index..].copy_from_slice(&original[write_index..dest_len]);
321+
dest_len
322+
} else if write_index < original.len() {
323+
// case: some bytes come from original and the rest are "assumed" to be
324+
// zero-initialized (by the caller).
325+
destination[write_index..original.len()].copy_from_slice(&original[write_index..]);
326+
original.len()
327+
} else {
328+
// case: all N bytes are "assumed" to be zero-initialized (by the caller).
329+
write_index
330+
}
331+
}
332+
Ordering::Greater => {
333+
// It is an impossible scenario. Even if the diff is corrupt, or the lengths of destination and original are the same
334+
// or different, we'll not encounter this case. It only implies a logic error.
335+
return Err(DlpError::InfallibleError.into());
336+
}
337+
};
338+
339+
let (_, unwritten_bytes) = destination.split_at_mut(num_bytes_written);
340+
341+
Ok(unwritten_bytes)
262342
}
263343

264344
// private function that does the actual work.
@@ -297,6 +377,58 @@ mod tests {
297377
);
298378
}
299379

380+
fn get_example_expected_diff(
381+
changed_len: usize,
382+
// additional_changes must apply after index 78 (index-in-data) !!
383+
additional_changes: Vec<(u32, &[u8])>,
384+
) -> Vec<u8> {
385+
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
386+
387+
let mut expected_diff = vec![];
388+
389+
// changed_len (u32)
390+
expected_diff.extend_from_slice(&(changed_len as u32).to_le_bytes());
391+
392+
if additional_changes.is_empty() {
393+
// 2 (u32)
394+
expected_diff.extend_from_slice(&2u32.to_le_bytes());
395+
} else {
396+
expected_diff
397+
.extend_from_slice(&(2u32 + additional_changes.len() as u32).to_le_bytes());
398+
}
399+
400+
// -- offsets
401+
402+
// 0 11 (each u32)
403+
expected_diff.extend_from_slice(&0u32.to_le_bytes());
404+
expected_diff.extend_from_slice(&11u32.to_le_bytes());
405+
406+
// 4 71 (each u32)
407+
expected_diff.extend_from_slice(&4u32.to_le_bytes());
408+
expected_diff.extend_from_slice(&71u32.to_le_bytes());
409+
410+
let mut offset_in_diff = 12u32;
411+
for (offset_in_data, diff) in additional_changes.iter() {
412+
expected_diff.extend_from_slice(&offset_in_diff.to_le_bytes());
413+
expected_diff.extend_from_slice(&offset_in_data.to_le_bytes());
414+
offset_in_diff += diff.len() as u32;
415+
}
416+
417+
// -- segments --
418+
419+
// 11 12 13 14 (each u8)
420+
expected_diff.extend_from_slice(&0x01020304u32.to_le_bytes());
421+
// 71 72 ... 78 (each u8)
422+
expected_diff.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
423+
424+
// append diff from additional_changes
425+
for (_, diff) in additional_changes.iter() {
426+
expected_diff.extend_from_slice(diff);
427+
}
428+
429+
expected_diff
430+
}
431+
300432
#[test]
301433
fn test_using_example_data() {
302434
let original = [0; 100];
@@ -311,42 +443,99 @@ mod tests {
311443

312444
let actual_diff = compute_diff(&original, &changed);
313445
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
314-
let expected_diff = {
315-
// expected: | 100 | 2 | 0 11 | 4 71 | 11 12 13 14 71 72 ... 78 |
446+
let expected_diff = get_example_expected_diff(changed.len(), vec![]);
316447

317-
let mut serialized = vec![];
448+
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
449+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
318450

319-
// 100 (u32)
320-
serialized.extend_from_slice(&(changed.len() as u32).to_le_bytes());
451+
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
321452

322-
// 2 (u32)
323-
serialized.extend_from_slice(&2u32.to_le_bytes());
453+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
324454

325-
// 0 11 (each u32)
326-
serialized.extend_from_slice(&0u32.to_le_bytes());
327-
serialized.extend_from_slice(&11u32.to_le_bytes());
455+
let expected_changed = {
456+
let mut destination = vec![255; original.len()];
457+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
458+
destination
459+
};
460+
461+
assert_eq!(changed.as_slice(), expected_changed.as_slice());
462+
}
463+
464+
#[test]
465+
fn test_shrunk_account_data() {
466+
// Note that changed_len cannot be lower than 79 because the last "changed" index is
467+
// 78 in the diff.
468+
const CHANGED_LEN: usize = 80;
328469

329-
// 4 71 (each u32)
330-
serialized.extend_from_slice(&4u32.to_le_bytes());
331-
serialized.extend_from_slice(&71u32.to_le_bytes());
470+
let original = vec![0; 100];
471+
let changed = {
472+
let mut copy = original.clone();
473+
copy.truncate(CHANGED_LEN);
332474

333-
// 11 12 13 14 (each u8)
334-
serialized.extend_from_slice(&0x01020304u32.to_le_bytes());
335-
// 71 72 ... 78 (each u8)
336-
serialized.extend_from_slice(&0x0102030405060708u64.to_le_bytes());
337-
serialized
475+
// | 11 | 12 | 13 | 14 |
476+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
477+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
478+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
479+
copy
338480
};
339481

482+
let actual_diff = compute_diff(&original, &changed);
483+
484+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
485+
486+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![]);
487+
340488
assert_eq!(actual_diff.len(), 4 + 4 + 8 + 8 + (4 + 8));
341489
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
342490

343-
let expected_changed = apply_diff_copy(&original, &actual_diffset).unwrap();
491+
let expected_changed = {
492+
let mut destination = vec![255; CHANGED_LEN];
493+
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
494+
destination
495+
};
344496

345497
assert_eq!(changed.as_slice(), expected_changed.as_slice());
498+
}
499+
500+
#[test]
501+
fn test_expanded_account_data() {
502+
const CHANGED_LEN: usize = 120;
503+
504+
let original = vec![0; 100];
505+
let changed = {
506+
let mut copy = original.clone();
507+
copy.resize(CHANGED_LEN, 0); // new bytes are zero-initialized
508+
509+
// | 11 | 12 | 13 | 14 |
510+
copy[11..=14].copy_from_slice(&0x01020304u32.to_le_bytes());
511+
// | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 |
512+
copy[71..=78].copy_from_slice(&0x0102030405060708u64.to_le_bytes());
513+
copy
514+
};
515+
516+
let actual_diff = compute_diff(&original, &changed);
517+
518+
let actual_diffset = DiffSet::try_new(&actual_diff).unwrap();
519+
520+
// When an account expands, the extra bytes at the end become part of the diff, even if
521+
// all of them are zeroes, that is why (100, &[0; 20]) is passed as additional_changes to
522+
// the following function.
523+
//
524+
// TODO (snawaz): we could optimize compute_diff to not include the zero bytes which are
525+
// part of the expansion.
526+
let expected_diff = get_example_expected_diff(CHANGED_LEN, vec![(100, &[0; 20])]);
527+
528+
assert_eq!(actual_diff.len(), 4 + 4 + (8 + 8) + (4 + 8) + (4 + 4 + 20));
529+
assert_eq!(actual_diff.as_slice(), expected_diff.as_slice());
346530

347531
let expected_changed = {
348-
let mut destination = vec![255; original.len()];
349-
merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
532+
let mut destination = vec![255; CHANGED_LEN];
533+
let unwritten = merge_diff_copy(&mut destination, &original, &actual_diffset).unwrap();
534+
535+
// TODO (snawaz): unwritten == &mut [], is because currently the expanded bytes are part of the diff.
536+
// Once compute_diff is optimized further, unwritten must be &mut [0; 20].
537+
assert_eq!(unwritten, &mut []);
538+
350539
destination
351540
};
352541

src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ pub enum DlpError {
8181
UndelegateBufferAlreadyInitialized = 36,
8282
#[error("Undelegate buffer PDA immutable")]
8383
UndelegateBufferImmutable = 37,
84+
#[error("An infallible error is encountered possibly due to logic error")]
85+
InfallibleError = 100,
8486
}
8587

8688
impl From<DlpError> for ProgramError {

0 commit comments

Comments
 (0)