Skip to content

Commit 3a02506

Browse files
ongspxmtjgurwara99raklaptudirm
authored
merge: feat: dijkstra closest distance implementation (#415)
* refactor * feat: implement dijkstra * test: adding multi path test cases * fix: varible naming in test * refactor * feat: generalize heap * feat: new implementation * cleanup * codespell * nit * fix conflict * Update graph/dijkstra.go Co-authored-by: Taj <[email protected]> * Update graph/dijkstra.go Co-authored-by: Taj <[email protected]> * (refactor) Co-authored-by: Taj <[email protected]> Co-authored-by: Rak Laptudirm <[email protected]>
1 parent 153ee2c commit 3a02506

File tree

3 files changed

+217
-20
lines changed

3 files changed

+217
-20
lines changed

graph/dijkstra.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package graph
2+
3+
import "github.com/TheAlgorithms/Go/sort"
4+
5+
type Item struct {
6+
node int
7+
dist int
8+
}
9+
10+
func (a Item) More(b interface{}) bool {
11+
// reverse direction for minheap
12+
return a.dist < b.(Item).dist
13+
}
14+
func (a Item) Idx() int {
15+
return a.node
16+
}
17+
18+
func (g *Graph) Dijkstra(start, end int) (int, bool) {
19+
visited := make(map[int]bool)
20+
nodes := make(map[int]*Item)
21+
22+
nodes[start] = &Item{
23+
dist: 0,
24+
node: start,
25+
}
26+
pq := sort.MaxHeap{}
27+
pq.Init(nil)
28+
pq.Push(*nodes[start])
29+
30+
visit := func(curr Item) {
31+
visited[curr.node] = true
32+
for n, d := range g.edges[curr.node] {
33+
if visited[n] {
34+
continue
35+
}
36+
37+
item := nodes[n]
38+
dist2 := curr.dist + d
39+
if item == nil {
40+
nodes[n] = &Item{node: n, dist: dist2}
41+
pq.Push(*nodes[n])
42+
} else if item.dist > dist2 {
43+
item.dist = dist2
44+
pq.Update(*item)
45+
}
46+
}
47+
}
48+
49+
for pq.Size() > 0 {
50+
curr := pq.Pop().(Item)
51+
if curr.node == end {
52+
break
53+
}
54+
visit(curr)
55+
}
56+
57+
item := nodes[end]
58+
if item == nil {
59+
return -1, false
60+
}
61+
return item.dist, true
62+
}

graph/dijkstra_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package graph
2+
3+
import (
4+
"testing"
5+
)
6+
7+
var tc_dijkstra = []struct {
8+
name string
9+
edges [][]int
10+
node0 int
11+
node1 int
12+
expected int
13+
}{
14+
{
15+
"straight line graph",
16+
[][]int{{0, 1, 5}, {1, 2, 2}},
17+
0, 2, 7,
18+
},
19+
{
20+
"unconnected node",
21+
[][]int{{0, 1, 5}},
22+
0, 2, -1,
23+
},
24+
{
25+
"double paths",
26+
[][]int{{0, 1, 5}, {1, 3, 5}, {0, 2, 5}, {2, 3, 4}},
27+
0, 3, 9,
28+
},
29+
{
30+
"double paths extended",
31+
[][]int{{0, 1, 5}, {1, 3, 5}, {0, 2, 5}, {2, 3, 4}, {3, 4, 1}},
32+
0, 4, 10,
33+
},
34+
}
35+
36+
func TestDijkstra(t *testing.T) {
37+
for _, tc := range tc_dijkstra {
38+
t.Run(tc.name, func(t *testing.T) {
39+
var graph Graph
40+
for _, edge := range tc.edges {
41+
graph.AddWeightedEdge(edge[0], edge[1], edge[2])
42+
}
43+
44+
actual, _ := graph.Dijkstra(tc.node0, tc.node1)
45+
if actual != tc.expected {
46+
t.Errorf("expected %d, got %d, from node %d to %d, with %v",
47+
tc.expected, actual, tc.node0, tc.node1, tc.edges)
48+
}
49+
})
50+
}
51+
}

sort/heapsort.go

Lines changed: 104 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,134 @@
11
package sort
22

3-
type maxHeap struct {
4-
slice []int
3+
type MaxHeap struct {
4+
slice []Comparable
55
heapSize int
6+
indices map[int]int
67
}
78

8-
func buildMaxHeap(slice []int) maxHeap {
9-
h := maxHeap{slice: slice, heapSize: len(slice)}
10-
for i := len(slice) / 2; i >= 0; i-- {
11-
h.MaxHeapify(i)
9+
func buildMaxHeap(slice0 []int) MaxHeap {
10+
var slice []Comparable
11+
for _, i := range slice0 {
12+
slice = append(slice, Int(i))
1213
}
14+
h := MaxHeap{}
15+
h.Init(slice)
1316
return h
1417
}
1518

16-
func (h maxHeap) MaxHeapify(i int) {
19+
func (h *MaxHeap) Init(slice []Comparable) {
20+
if slice == nil {
21+
slice = make([]Comparable, 0)
22+
}
23+
24+
h.slice = slice
25+
h.heapSize = len(slice)
26+
h.indices = make(map[int]int)
27+
h.Heapify()
28+
}
29+
30+
func (h MaxHeap) Heapify() {
31+
for i, v := range h.slice {
32+
h.indices[v.Idx()] = i
33+
}
34+
for i := h.heapSize / 2; i >= 0; i-- {
35+
h.heapifyDown(i)
36+
}
37+
}
38+
39+
func (h *MaxHeap) Pop() Comparable {
40+
if h.heapSize == 0 {
41+
return nil
42+
}
43+
44+
i := h.slice[0]
45+
h.heapSize--
46+
47+
h.slice[0] = h.slice[h.heapSize]
48+
h.updateidx(0)
49+
h.heapifyDown(0)
50+
51+
h.slice = h.slice[0:h.heapSize]
52+
return i
53+
}
54+
55+
func (h *MaxHeap) Push(i Comparable) {
56+
h.slice = append(h.slice, i)
57+
h.updateidx(h.heapSize)
58+
h.heapifyUp(h.heapSize)
59+
h.heapSize++
60+
}
61+
62+
func (h MaxHeap) Size() int {
63+
return h.heapSize
64+
}
65+
66+
func (h MaxHeap) Update(i Comparable) {
67+
h.slice[h.indices[i.Idx()]] = i
68+
h.heapifyUp(h.indices[i.Idx()])
69+
h.heapifyDown(h.indices[i.Idx()])
70+
}
71+
72+
func (h MaxHeap) updateidx(i int) {
73+
h.indices[h.slice[i].Idx()] = i
74+
}
75+
76+
func (h MaxHeap) heapifyUp(i int) {
77+
if i == 0 {
78+
return
79+
}
80+
p := i / 2
81+
82+
if h.slice[i].More(h.slice[p]) {
83+
h.slice[i], h.slice[p] = h.slice[p], h.slice[i]
84+
h.updateidx(i)
85+
h.updateidx(p)
86+
h.heapifyUp(p)
87+
}
88+
}
89+
90+
func (h MaxHeap) heapifyDown(i int) {
1791
l, r := 2*i+1, 2*i+2
1892
max := i
1993

20-
if l < h.size() && h.slice[l] > h.slice[max] {
94+
if l < h.heapSize && h.slice[l].More(h.slice[max]) {
2195
max = l
2296
}
23-
if r < h.size() && h.slice[r] > h.slice[max] {
97+
if r < h.heapSize && h.slice[r].More(h.slice[max]) {
2498
max = r
2599
}
26-
//log.Printf("MaxHeapify(%v): l,r=%v,%v; max=%v\t%v\n", i, l, r, max, h.slice)
27100
if max != i {
28101
h.slice[i], h.slice[max] = h.slice[max], h.slice[i]
29-
h.MaxHeapify(max)
102+
h.updateidx(i)
103+
h.updateidx(max)
104+
h.heapifyDown(max)
30105
}
31106
}
32107

33-
func (h maxHeap) size() int { return h.heapSize } // ???
108+
type Comparable interface {
109+
Idx() int
110+
More(interface{}) bool
111+
}
112+
type Int int
113+
114+
func (a Int) More(b interface{}) bool {
115+
return a > b.(Int)
116+
}
117+
func (a Int) Idx() int {
118+
return int(a)
119+
}
34120

35121
func HeapSort(slice []int) []int {
36122
h := buildMaxHeap(slice)
37-
//log.Println(slice)
38123
for i := len(h.slice) - 1; i >= 1; i-- {
39124
h.slice[0], h.slice[i] = h.slice[i], h.slice[0]
40125
h.heapSize--
41-
h.MaxHeapify(0)
42-
/*if i == len(h.slice)-1 || i == len(h.slice)-3 || i == len(h.slice)-5 {
43-
element := (i - len(h.slice)) * -1
44-
fmt.Println("Heap after removing ", element, " elements")
45-
fmt.Println(h.slice)
126+
h.heapifyDown(0)
127+
}
46128

47-
}*/
129+
res := []int{}
130+
for _, i := range h.slice {
131+
res = append(res, int(i.(Int)))
48132
}
49-
return h.slice
133+
return res
50134
}

0 commit comments

Comments
 (0)