支持元素更新的 D 元最小堆模板

在做 CMU 15-445 的课程项目时,需要实现一个支持元素更新的堆。

STL 中的堆(std::make_heap / std::push_heap / std::pop_heap)不支持更新已有元素,但我们可以利用元素上浮和元素下沉操作来实现更新。做法很简单:对于最小堆,若元素的值变小,就将其上浮;若值变大,就将其下沉,这样即可恢复堆的性质。

不幸的是STL又没有暴露元素上浮和元素下沉的接口,于是只能自己写一个了。

为了能够按 key 更新元素的值,每个元素都附带一个 key 标签,并额外用一个哈希表记录每个 key 在堆数组中的下标。因此,最终的数据结构是“数组堆 + 哈希表索引”的组合,使用方式类似哈希表:凭 key 查找并更新元素。

实现时顺便支持了 D 元堆:元数(每个节点的子节点个数)作为模板参数 ARITY,默认为 4。

完整实现如下:
/**
 * KeyedMinHeap is an indexed D-ary min-heap. Every element is identified by a
 * unique key, and the score of any element can be changed through its key.
 *
 * Operations:
 * - Insert:    Insert a key with a score. If the key is already present, its
 *              score is updated instead (no duplicate entries are created).
 * - Update:    Set the score of a key that is in the heap.
 * - Increment: Increase the score of a key that is in the heap by one.
 * - TopKey:    Get the key with the minimum score.
 * - TopScore:  Get the minimum score.
 * - Pop:       Remove the key with the minimum score.
 * - Erase:     Remove an arbitrary key from the heap.
 * - Contains:  Check whether a key is in the heap.
 *
 * Layout: `heap_` holds (key, score) pairs in array-based D-ary heap order and
 * `keys_map_` maps every key to its current slot in `heap_`. Whenever entries
 * move, both containers are updated together so that
 * keys_map_[heap_[i].first] == i holds for every slot i.
 *
 * Preconditions are checked with assert() only: Update, Increment and Erase
 * require the key to be present; TopKey, TopScore and Pop require a non-empty
 * heap.
 *
 * @tparam Key   key type; must be usable as a std::unordered_map key.
 * @tparam ARITY number of children per node (must be at least 1).
 */
template <typename Key, std::size_t ARITY = 4>
class KeyedMinHeap {
  static_assert(ARITY >= 1, "KeyedMinHeap requires ARITY >= 1");

 public:
  using heap_index_t = std::size_t;
  using score_t = std::size_t;

 public:
  KeyedMinHeap() = default;

  /// Inserts `key` with `score`. If `key` is already present this behaves
  /// like Update(key, score), so the heap never holds two entries for a key.
  auto Insert(const Key &key, score_t score) -> void {
    if (keys_map_.find(key) != keys_map_.end()) {
      // Overwriting the map entry here would orphan the old heap slot and
      // break the key -> slot invariant, so treat it as an update instead.
      Update(key, score);
      return;
    }
    keys_map_.emplace(key, heap_.size());
    heap_.emplace_back(key, score);
    BubbleUp(heap_.size() - 1);
  }

  /// Increases the score of `key` by one. Precondition: `key` is present.
  auto Increment(const Key &key) -> void {
    auto index_iter = keys_map_.find(key);
    assert(index_iter != keys_map_.end() && "Key not found");
    heap_index_t h_index = index_iter->second;

    ++heap_[h_index].second;
    // A larger score can only violate the heap property downward.
    BubbleDown(h_index);
  }

  /// Sets the score of `key` to `new_score`. Precondition: `key` is present.
  auto Update(const Key &key, score_t new_score) -> void {
    auto index_iter = keys_map_.find(key);
    assert(index_iter != keys_map_.end() && "Key not found");
    heap_index_t h_index = index_iter->second;

    score_t old_score = heap_[h_index].second;
    if (old_score == new_score) {
      return;
    }

    heap_[h_index].second = new_score;

    // Smaller score: may need to move toward the root; larger: toward leaves.
    if (new_score < old_score) {
      BubbleUp(h_index);
    } else {
      BubbleDown(h_index);
    }
  }

  /// Returns the key with the minimum score. Precondition: heap is non-empty.
  auto TopKey() const -> Key {
    assert(!heap_.empty() && "Heap is empty");
    return heap_.front().first;
  }

  /// Returns the minimum score. Precondition: heap is non-empty.
  auto TopScore() const -> score_t {
    assert(!heap_.empty() && "Heap is empty");
    return heap_.front().second;
  }

  /// Removes the entry with the minimum score. Precondition: non-empty heap.
  auto Pop() -> void {
    assert(!heap_.empty() && "Heap is empty");
    RemoveAt(0);
  }

  /// Removes `key` from the heap. Precondition: `key` is present.
  auto Erase(const Key &key) -> void {
    auto entry = keys_map_.find(key);
    assert(entry != keys_map_.end() && "Key not found");
    RemoveAt(entry->second);
  }

  /// Returns true if `key` is currently stored in the heap.
  auto Contains(const Key &key) const -> bool { return keys_map_.find(key) != keys_map_.end(); }

  auto Size() const -> std::size_t { return heap_.size(); }

  auto Empty() const -> bool { return heap_.empty(); }

 private:
  auto ParentIndex(heap_index_t h_index) const -> heap_index_t { return (h_index - 1) / ARITY; }

  // Returns the index of the child_number-th child of the node at h_index.
  // The children are numbered from 1 to ARITY.
  auto ChildIndex(heap_index_t h_index, std::size_t child_number) const -> heap_index_t {
    return ARITY * h_index + child_number;
  }

  // Swaps two heap slots and the recorded positions of their keys together,
  // keeping keys_map_ consistent with heap_.
  void SwapSlots(heap_index_t a, heap_index_t b) {
    std::swap(keys_map_[heap_[a].first], keys_map_[heap_[b].first]);
    std::swap(heap_[a], heap_[b]);
  }

  // Removes the entry at slot h_index from both containers and restores the
  // heap property. Precondition: h_index < heap_.size().
  void RemoveAt(heap_index_t h_index) {
    const heap_index_t last = heap_.size() - 1;
    keys_map_.erase(heap_[h_index].first);

    if (h_index == last) {
      // Removing the final slot (including the single-element case): nothing
      // has to be moved, and no key may be re-inserted into keys_map_.
      heap_.pop_back();
      return;
    }

    // Fill the hole with the last entry, then re-sift it. The moved entry may
    // be smaller than its new parent or larger than a new child; if BubbleUp
    // moves it, the subsequent BubbleDown is a no-op.
    std::swap(heap_[h_index], heap_[last]);
    heap_.pop_back();
    keys_map_[heap_[h_index].first] = h_index;
    BubbleUp(h_index);
    BubbleDown(h_index);
  }

  // Moves the entry at h_index toward the root while it is smaller than its
  // parent.
  void BubbleUp(heap_index_t h_index) {
    while (h_index != 0) {
      const heap_index_t parent = ParentIndex(h_index);
      if (!(heap_[h_index].second < heap_[parent].second)) {
        break;
      }
      SwapSlots(h_index, parent);
      h_index = parent;
    }
  }

  // Moves the entry at h_index toward the leaves while some child has a
  // smaller score, always swapping with the smallest child.
  void BubbleDown(heap_index_t h_index) {
    while (true) {
      heap_index_t smallest = h_index;
      for (std::size_t i = 1; i <= ARITY; ++i) {
        const heap_index_t child = ChildIndex(h_index, i);
        if (child >= heap_.size()) {
          break;  // Children occupy contiguous slots; the rest are absent too.
        }
        if (heap_[child].second < heap_[smallest].second) {
          smallest = child;
        }
      }
      if (smallest == h_index) {
        break;
      }
      SwapSlots(h_index, smallest);
      h_index = smallest;
    }
  }

 private:
  std::unordered_map<Key, heap_index_t> keys_map_;  // key -> slot in heap_
  std::vector<std::pair<Key, score_t>> heap_;       // (key, score) in heap order
};