Skip to content

Commit 23e9b43

Browse files
committed
core: move all capser code inside casper_util
1 parent 5379e73 commit 23e9b43

File tree

2 files changed

+53
-41
lines changed

2 files changed

+53
-41
lines changed

core/blockchain.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,8 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
884884
bc.mu.Lock()
885885
defer bc.mu.Unlock()
886886

887+
currentBlock := bc.CurrentBlock()
888+
localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
887889
externTd := new(big.Int).Add(block.Difficulty(), ptd)
888890

889891
// Irrelevant of the canonical status, write the block itself to the database
@@ -954,20 +956,14 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
954956
if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
955957
return NonStatTy, err
956958
}
957-
958-
currentBlock := bc.CurrentBlock()
959-
localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
960959
// If the total difficulty is higher than our known, add it to the canonical chain
961-
setNewHead := false
960+
setNewHead := externTd.Cmp(localTd) > 0
962961
if bc.chainConfig.IsCasper(block.Number()) {
963-
// Use Casper's fork choice rule
964-
localScore := bc.getScore(currentBlock)
965-
externScore := bc.getScore(block)
966-
setNewHead = externScore.Cmp(localScore) > 0 && bc.safeForLastFinalizedBlock(block, state)
967-
} else {
968-
setNewHead = externTd.Cmp(localTd) > 0
962+
setNewHead, err = bc.acceptNewCasperBlock(currentBlock, block, state)
963+
if err != nil {
964+
return NonStatTy, err
965+
}
969966
}
970-
971967
// Try to reduce the vulnerability to selfish mining.
972968
// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
973969
if !setNewHead && externTd.Cmp(localTd) == 0 {

core/casper_util.go

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,69 +27,85 @@ import (
2727
"github.com/ethereum/go-ethereum/log"
2828
)
2929

30-
// getScore returns the score of a block
31-
func (bc *BlockChain) getScore(block *types.Block) *big.Int {
32-
score := bc.GetTd(block.Hash(), block.NumberU64())
33-
34-
if bc.chainConfig.IsCasper(block.Number()) {
35-
if scored, casperScore := bc.getCasperScore(block); scored {
36-
casperScale := new(big.Int).Exp(big.NewInt(int64(10)), big.NewInt(int64(40)), nil)
37-
casperScore.Mul(casperScore, casperScale)
38-
score.Add(score, casperScore)
30+
func (bc *BlockChain) acceptNewCasperBlock(currentBlock *types.Block, newBlock *types.Block, newState *state.StateDB) (bool, error) {
31+
currentScore, err := bc.getScore(currentBlock)
32+
if err != nil {
33+
return false, err
34+
}
35+
newScore, err := bc.getScore(newBlock)
36+
if err != nil {
37+
return false, err
38+
}
39+
if newScore.Cmp(currentScore) > 0 {
40+
// Check if we will revert any finalized block in the Casper case
41+
safe, err := bc.safeForLastFinalizedBlock(newBlock, newState)
42+
if err != nil {
43+
return false, err
3944
}
45+
return safe, nil
4046
}
41-
return score
47+
return false, nil
4248
}
4349

44-
// getCasperScore returns the score of a block from Casper's perspective
45-
func (bc *BlockChain) getCasperScore(block *types.Block) (bool, *big.Int) {
50+
// getScore returns the score of a block from Casper's perspective
51+
func (bc *BlockChain) getScore(block *types.Block) (*big.Int, error) {
52+
score := bc.GetTd(block.Hash(), block.NumberU64())
53+
casperScore, err := bc.getLastJustifiedEpoch(block)
54+
if err != nil {
55+
return nil, err
56+
}
57+
casperScale := new(big.Int).Exp(big.NewInt(10), big.NewInt(40), nil)
58+
casperScore.Mul(casperScore, casperScale)
59+
score.Add(score, casperScore)
60+
return score, nil
61+
}
62+
63+
// getLastJustifiedEpoch returns the last justified epoch for a given block
64+
func (bc *BlockChain) getLastJustifiedEpoch(block *types.Block) (*big.Int, error) {
4665
state, err := bc.StateAt(block.Root())
4766
if err != nil {
48-
return false, nil
67+
return nil, err
4968
}
5069
stateBackend := NewStateBackend(block, state, bc)
5170
contract, err := casper.New(stateBackend)
5271
if err != nil {
5372
log.Warn("Failed to get Casper contract", "err", err)
54-
return false, nil
73+
return nil, err
5574
}
5675
justified, err := contract.GetLastJustifiedEpoch(&bind.CallOpts{})
5776
if err != nil {
5877
log.Warn("Failed to get current chain status from Casper", "err", err)
59-
return false, nil
78+
return nil, err
6079
}
61-
return true, justified
80+
return justified, nil
6281
}
6382

6483
// safeForLastFinalizedBlock returns true if the new head will NOT revert the last finalized block
65-
func (bc *BlockChain) safeForLastFinalizedBlock(newBlock *types.Block, newState *state.StateDB) bool {
84+
func (bc *BlockChain) safeForLastFinalizedBlock(newBlock *types.Block, newState *state.StateDB) (bool, error) {
6685
stateBackend := NewStateBackend(newBlock, newState, bc)
6786
contract, err := casper.New(stateBackend)
6887
if err != nil {
6988
log.Warn("Failed to get Casper contract", "err", err)
70-
return false
89+
return false, err
7190
}
7291
blockNumber, err := contract.GetLastFinalizedEpoch(&bind.CallOpts{})
7392
if err != nil {
7493
log.Warn("Failed to get current chain status from Casper", "err", err)
75-
return false
94+
return false, err
7695
}
7796
hashBytes, err := contract.GetCheckpointHashes(&bind.CallOpts{}, blockNumber)
7897
if err != nil {
7998
log.Warn("Failed to get current chain status from Casper", "err", err)
80-
return false
99+
return false, err
81100
}
82101
blockHash := common.BytesToHash(hashBytes[:])
83102
parentBlock := bc.GetBlock(newBlock.ParentHash(), newBlock.NumberU64()-1)
84-
for {
85-
if parentBlock == nil || blockNumber.Cmp(parentBlock.Number()) > 0 {
86-
return false
87-
} else if parentBlock.Hash() == blockHash {
88-
// The last finalized block IS currentBlock's ancestor
89-
return true
90-
} else {
91-
parentBlock = bc.GetBlock(parentBlock.ParentHash(), parentBlock.NumberU64()-1)
92-
}
103+
// Find the correct block
104+
for ; parentBlock != nil && parentBlock.Number().Cmp(blockNumber) > 0; parentBlock = bc.GetBlock(parentBlock.ParentHash(), parentBlock.NumberU64()-1) {
105+
}
106+
// Compare its hash
107+
if parentBlock == nil {
108+
return false, nil
93109
}
94-
return false
110+
return parentBlock.Hash() == blockHash, nil
95111
}

0 commit comments

Comments
 (0)