Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 317 additions & 2 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import concurrent.futures
import itertools
import uuid
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import defaultdict
from concurrent.futures import Future
from datetime import datetime
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Callable, Dict, Generic, List, Optional, Set, Tuple

Expand Down Expand Up @@ -57,7 +58,7 @@
from pyiceberg.partitioning import (
PartitionSpec,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRefType
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
Expand Down Expand Up @@ -88,6 +89,7 @@

if TYPE_CHECKING:
from pyiceberg.table import Transaction
from pyiceberg.table.metadata import TableMetadata


def _new_manifest_file_name(num: int, commit_uuid: uuid.UUID) -> str:
Expand Down Expand Up @@ -794,6 +796,269 @@ def merge_manifests(self, manifests: List[ManifestFile]) -> List[ManifestFile]:
return merged_manifests


# Branch Merge Strategy Enums and Classes


class BranchMergeStrategy(Enum):
    """Git-style strategies for merging one Iceberg table branch into another.

    The member that is chosen determines what the resulting commit history and
    snapshot structure are meant to look like. Each member's value is the
    lowercase, underscore-separated form of its name.
    """

    # Classic merge: a new "merge commit" joins the two branches and the
    # history of both is preserved.
    MERGE = "merge"

    # History is re-written by placing the commits of one branch on top of the
    # other, which yields a linear history.
    REBASE = "rebase"

    # All commits of a feature branch are condensed into a single, clean commit
    # on the target branch.
    SQUASH = "squash"

    # One specific, individual commit is selected from one branch and applied
    # to another.
    CHERRY_PICK = "cherry_pick"

    # The target branch pointer is simply moved forward to the source branch's
    # head and no merge commit is created. Only possible when the target branch
    # has no new commits of its own.
    FAST_FORWARD = "fast_forward"


class _BaseBranchMergeStrategy(ABC):
    """Base class for branch merge strategy implementations.

    Subclasses implement :meth:`merge`. The helpers defined here answer ancestry
    questions by following ``parent_snapshot_id`` links through the table metadata.
    """

    @abstractmethod
    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """
        Execute the merge strategy.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message
        Returns:
            Tuple of (updates, requirements) for the merge operation
        """
        ...

    @staticmethod
    def _parent_snapshot(snapshot: Snapshot, table_metadata: "TableMetadata") -> Optional[Snapshot]:
        """Return the parent of ``snapshot`` as resolved by ``table_metadata``.

        Args:
            snapshot: Snapshot whose parent link is followed
            table_metadata: Metadata used to resolve the parent id
        Returns:
            None when ``snapshot`` has no parent id, otherwise whatever
            ``table_metadata.snapshot_by_id`` returns for that id
        """
        parent_id = snapshot.parent_snapshot_id
        # Compare with None explicitly: a parent id of 0 is falsy but still names
        # a snapshot, and a truthiness test would cut the ancestry walk short.
        if parent_id is None:
            return None
        return table_metadata.snapshot_by_id(parent_id)

    def _find_common_ancestor(
        self, source_ref: SnapshotRef, target_ref: SnapshotRef, table_metadata: "TableMetadata"
    ) -> Optional[Snapshot]:
        """Find the common ancestor snapshot between two branches.

        Args:
            source_ref: Ref whose snapshot starts the source-side walk
            target_ref: Ref whose snapshot starts the target-side walk
            table_metadata: Metadata used to resolve snapshot ids
        Returns:
            The first snapshot on the target's ancestry (the target head included)
            that also lies on the source's ancestry, or None when either head
            cannot be resolved or the two ancestries share no snapshot
        """
        source_snapshot = table_metadata.snapshot_by_id(source_ref.snapshot_id)
        target_snapshot = table_metadata.snapshot_by_id(target_ref.snapshot_id)

        if source_snapshot is None or target_snapshot is None:
            return None

        # Collect the id of every snapshot reachable from the source head.
        source_ancestors: Set[int] = set()
        current: Optional[Snapshot] = source_snapshot
        while current is not None:
            source_ancestors.add(current.snapshot_id)
            current = self._parent_snapshot(current, table_metadata)

        # Walk back from the target head; the first hit is the nearest shared snapshot.
        current = target_snapshot
        while current is not None:
            if current.snapshot_id in source_ancestors:
                return current
            current = self._parent_snapshot(current, table_metadata)

        return None

    def _is_fast_forward_possible(
        self, source_ref: SnapshotRef, target_ref: SnapshotRef, table_metadata: "TableMetadata"
    ) -> bool:
        """Check if a fast-forward merge is possible (target hasn't diverged).

        Args:
            source_ref: Ref of the branch being merged from
            target_ref: Ref of the branch being merged into
            table_metadata: Metadata used to resolve snapshot ids
        Returns:
            True when the target's snapshot is the source's snapshot or one of its
            ancestors; False otherwise, including when either head cannot be resolved
        """
        target_snapshot = table_metadata.snapshot_by_id(target_ref.snapshot_id)
        if target_snapshot is None:
            return False

        # Walk up the source ancestry looking for the target's snapshot.
        current = table_metadata.snapshot_by_id(source_ref.snapshot_id)
        while current is not None:
            if current.snapshot_id == target_snapshot.snapshot_id:
                return True
            current = self._parent_snapshot(current, table_metadata)

        return False


class _SquashMergeStrategy(_BaseBranchMergeStrategy):
    """Squash strategy: meant to fold all source-branch changes into a single commit.

    NOTE(review): no squashed snapshot is written yet. Every successful path only
    re-points the target branch at the source branch's current snapshot.
    """

    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """Build the update/requirement pair for a squash merge.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message (currently unused)
        Returns:
            Tuple of (updates, requirements): one ref update moving the target to the
            source's snapshot id and one assertion on the target's current snapshot id
        Raises:
            KeyError: If either branch name is missing from the metadata refs
            ValueError: If fast-forward is not possible and the source snapshot
                cannot be resolved
        """
        metadata = transaction.table_metadata
        source_ref = metadata.refs[source_branch]
        target_ref = metadata.refs[target_branch]

        if not self._is_fast_forward_possible(source_ref, target_ref, metadata):
            # Not a fast-forward. A complete implementation would have to analyse
            # the data files of both branches, resolve conflicts and write new
            # manifests. For now the only extra step is making sure the source
            # head can be resolved.
            if not metadata.snapshot_by_id(source_ref.snapshot_id):
                raise ValueError(f"Source snapshot not found for branch {source_branch}")

        # Fast-forward and simplified squash produce the same pair: move the
        # target ref to the source head, asserting the target head read above.
        ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch")
        target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id)
        return (ref_move,), (target_guard,)


class _MergeStrategy(_BaseBranchMergeStrategy):
    """Merge strategy: meant to create a two-parent merge commit that keeps both histories.

    NOTE(review): no merge commit is created yet. Every successful path only
    re-points the target branch at the source branch's current snapshot.
    """

    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """Build the update/requirement pair for a merge.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message (currently unused)
        Returns:
            Tuple of (updates, requirements): one ref update moving the target to the
            source's snapshot id and one assertion on the target's current snapshot id
        Raises:
            KeyError: If either branch name is missing from the metadata refs
            ValueError: If fast-forward is not possible and the branches share no
                common ancestor
        """
        metadata = transaction.table_metadata
        source_ref = metadata.refs[source_branch]
        target_ref = metadata.refs[target_branch]

        if not self._is_fast_forward_possible(source_ref, target_ref, metadata):
            # Diverged branches must at least share history. The actual three-way
            # merge is not implemented; the ancestor is only checked for existence.
            if not self._find_common_ancestor(source_ref, target_ref, metadata):
                raise ValueError(f"No common ancestor found between {source_branch} and {target_branch}")

        # Fast-forward and simplified merge produce the same pair: move the
        # target ref to the source head, asserting the target head read above.
        ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch")
        target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id)
        return (ref_move,), (target_guard,)


class _RebaseMergeStrategy(_BaseBranchMergeStrategy):
    """Rebase strategy: meant to replay the source branch's commits on top of the target.

    NOTE(review): no commits are replayed yet. Both paths only re-point the
    target branch at the source branch's current snapshot.
    """

    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """Build the update/requirement pair for a rebase merge.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message (currently unused)
        Returns:
            Tuple of (updates, requirements): one ref update moving the target to the
            source's snapshot id and one assertion on the target's current snapshot id
        Raises:
            KeyError: If either branch name is missing from the metadata refs
        """
        metadata = transaction.table_metadata
        source_ref = metadata.refs[source_branch]
        target_ref = metadata.refs[target_branch]

        if self._is_fast_forward_possible(source_ref, target_ref, metadata):
            # Target has not diverged: moving the pointer is the whole operation.
            ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch")
            target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id)
            return (ref_move,), (target_guard,)

        # Diverged case. A real rebase would find the commits made since the
        # divergence, replay each on top of the target and move the source branch
        # onto the new history. Until then this mirrors the fast-forward result.
        ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch")
        target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id)
        return (ref_move,), (target_guard,)


class _CherryPickStrategy(_BaseBranchMergeStrategy):
    """Cherry-pick strategy: meant to apply one specific source commit to the target.

    NOTE(review): no new snapshot is created here. The target branch is simply
    re-pointed at the source branch's current snapshot.
    """

    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """Build the update/requirement pair for a cherry-pick of the source head.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message (currently unused)
        Returns:
            Tuple of (updates, requirements): one ref update moving the target to the
            source's snapshot id and one assertion on the target's current snapshot id
        Raises:
            KeyError: If either branch name is missing from the metadata refs
        """
        refs = transaction.table_metadata.refs
        picked_snapshot_id = refs[source_branch].snapshot_id
        expected_target_id = refs[target_branch].snapshot_id

        # Only the latest source commit is considered; no ancestry check is made.
        ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=picked_snapshot_id, type="branch")
        target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=expected_target_id)
        return (ref_move,), (target_guard,)


class _FastForwardStrategy(_BaseBranchMergeStrategy):
    """Fast-forward strategy: advance the target branch pointer, never creating a merge commit."""

    def merge(
        self,
        source_branch: str,
        target_branch: str,
        transaction: "Transaction",
        merge_commit_message: Optional[str] = None,
    ) -> Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]:
        """Build the update/requirement pair that moves the target to the source head.

        Args:
            source_branch: Name of the source branch
            target_branch: Name of the target branch
            transaction: The table transaction
            merge_commit_message: Optional custom merge message (currently unused)
        Returns:
            Tuple of (updates, requirements): one ref update moving the target to the
            source's snapshot id and one assertion on the target's current snapshot id
        Raises:
            KeyError: If either branch name is missing from the metadata refs
            ValueError: If the target's snapshot is not on the source's ancestry
        """
        metadata = transaction.table_metadata
        source_ref = metadata.refs[source_branch]
        target_ref = metadata.refs[target_branch]

        # Unlike the other strategies, a diverged target is an error here.
        can_fast_forward = self._is_fast_forward_possible(source_ref, target_ref, metadata)
        if not can_fast_forward:
            raise ValueError(f"Fast-forward merge not possible between {source_branch} and {target_branch}")

        ref_move = SetSnapshotRefUpdate(ref_name=target_branch, snapshot_id=source_ref.snapshot_id, type="branch")
        target_guard = AssertRefSnapshotId(ref=target_branch, snapshot_id=target_ref.snapshot_id)
        return (ref_move,), (target_guard,)


def _get_merge_strategy_impl(strategy: BranchMergeStrategy) -> _BaseBranchMergeStrategy:
    """Get the implementation for a given merge strategy.

    Args:
        strategy: The merge strategy to look up
    Returns:
        A newly constructed instance of the class registered for ``strategy``
    Raises:
        KeyError: If ``strategy`` is not a key of the dispatch table
    """
    # The table holds classes rather than instances, so only the selected
    # strategy is constructed instead of all five on every call.
    strategy_classes = {
        BranchMergeStrategy.MERGE: _MergeStrategy,
        BranchMergeStrategy.SQUASH: _SquashMergeStrategy,
        BranchMergeStrategy.REBASE: _RebaseMergeStrategy,
        BranchMergeStrategy.CHERRY_PICK: _CherryPickStrategy,
        BranchMergeStrategy.FAST_FORWARD: _FastForwardStrategy,
    }
    return strategy_classes[strategy]()


class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
"""
Run snapshot management operations using APIs.
Expand Down Expand Up @@ -915,6 +1180,56 @@ def remove_branch(self, branch_name: str) -> ManageSnapshots:
"""
return self._remove_ref_snapshot(ref_name=branch_name)

def merge_branch(
    self,
    source_branch: str,
    target_branch: str = "main",
    strategy: BranchMergeStrategy = BranchMergeStrategy.MERGE,
    merge_commit_message: Optional[str] = None,
    delete_source_branch: bool = False,
) -> ManageSnapshots:
    """
    Merge a source branch into a target branch using the specified merge strategy.

    Args:
        source_branch (str): Name of the branch to merge from
        target_branch (str): Name of the branch to merge into (default: "main")
        strategy (BranchMergeStrategy): The merge strategy to use
        merge_commit_message (Optional[str]): Custom message for the merge commit
        delete_source_branch (bool): Whether to delete the source branch after merge (default: False)
    Returns:
        This for method chaining
    Raises:
        ValueError: If either name is not a known ref, if both names are the same,
            if either ref is not a branch, or if the chosen strategy rejects the merge
    """
    refs = self._transaction.table_metadata.refs

    # Validate that both names are known refs.
    if source_branch not in refs:
        raise ValueError(f"Source branch '{source_branch}' does not exist")
    if target_branch not in refs:
        raise ValueError(f"Target branch '{target_branch}' does not exist")

    if source_branch == target_branch:
        raise ValueError("Cannot merge a branch into itself")

    # The refs mapping is not limited to branches. Without this check a tag could
    # be rewritten as a branch (target) or removed by delete_source_branch (source).
    for role, ref_name in (("Source", source_branch), ("Target", target_branch)):
        if refs[ref_name].snapshot_ref_type != SnapshotRefType.BRANCH:
            raise ValueError(f"{role} ref '{ref_name}' is not a branch")

    # Get the appropriate merge strategy implementation
    merge_strategy_impl = _get_merge_strategy_impl(strategy)

    # Execute the merge; the strategy only builds updates and requirements, which
    # are queued on this object.
    updates, requirements = merge_strategy_impl.merge(
        source_branch=source_branch,
        target_branch=target_branch,
        transaction=self._transaction,
        merge_commit_message=merge_commit_message,
    )

    self._updates += updates
    self._requirements += requirements

    # Delete source branch if requested, through the same helper remove_branch uses.
    if delete_source_branch:
        self._remove_ref_snapshot(ref_name=source_branch)

    return self


class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
"""Expire snapshots by ID.
Expand Down
Loading