AlgoPlus v0.1.0
Loading...
Searching...
No Matches
rabin_karp.h
1#ifndef RABIN_KARP_H
2#define RABIN_KARP_H
3
4#ifdef __cplusplus
5#include <string>
6#include <vector>
7#endif
8
namespace {
/// Radix of the polynomial rolling hash.
const int base = 26;
/// Prime modulus that keeps every hash value below 2^30.
const int modulus = 1000000007;

/**
 * @brief Computes the polynomial hash of the half-open range [start, end) of a string.
 *
 * The hash is  sum( c_k * base^(len-1-k) ) mod modulus,  where c_k is the k-th
 * character of the range (the first character carries the highest power) and
 * len = end - start. Characters are read as unsigned bytes so that the result
 * does not depend on whether plain `char` is signed on the target platform.
 *
 * @param str   the string to hash
 * @param start index of the first character of the range
 * @param end   index one past the last character of the range
 * @return the hash value in [0, modulus); 0 when the range is empty
 *
 * @note The caller must ensure start <= end <= str.size(); no bounds checks
 *       are performed.
 */
size_t compute_hash(const std::string& str, size_t start, size_t end) {
    // 64-bit arithmetic: every intermediate product stays below base * modulus,
    // so nothing overflows even where size_t is only 32 bits wide.
    const unsigned long long mod = static_cast<unsigned long long>(modulus);
    const unsigned long long radix = static_cast<unsigned long long>(base);

    unsigned long long hash_value = 0;
    // Horner's scheme over the requested range only.
    for (size_t i = start; i < end; ++i) {
        const unsigned long long ch = static_cast<unsigned char>(str[i]);
        hash_value = (hash_value * radix + ch) % mod;
    }
    return static_cast<size_t>(hash_value);
}

/**
 * @brief Compares `length` characters of two strings to confirm a hash match.
 *
 * @param str1   first string
 * @param start1 offset into str1 where the comparison starts
 * @param str2   second string
 * @param start2 offset into str2 where the comparison starts
 * @param length number of characters to compare
 * @return true if the two ranges are character-for-character equal (always
 *         true when length is 0), false on the first mismatch
 *
 * @note Both ranges must lie inside their strings; no bounds checks are
 *       performed.
 */
bool check_collision(const std::string& str1, size_t start1, const std::string& str2, size_t start2,
                     size_t length) {
    for (size_t i = 0; i < length; ++i) {
        if (str1[start1 + i] != str2[start2 + i]) {
            return false;
        }
    }
    return true;
}
}  // namespace

/**
 * @brief Finds every occurrence of `pattern` in `text` with the Rabin-Karp algorithm.
 *
 * A rolling hash of each window of pattern.length() characters is compared
 * with the hash of the pattern; on equal hashes the window is verified
 * character by character, so hash collisions never produce false positives.
 *
 * @param text    the text to search in
 * @param pattern the pattern to search for
 * @return the start indices of all (possibly overlapping) occurrences, in
 *         ascending order. An empty pattern matches at every index from 0 to
 *         text.length() inclusive. The result is empty when the text is
 *         shorter than the pattern.
 *
 * Declared inline so that including this header from several translation
 * units does not produce multiple-definition link errors.
 */
inline std::vector<size_t> rabin_karp(const std::string& text, const std::string& pattern) {
    std::vector<size_t> result;
    const size_t pattern_length = pattern.length();
    const size_t text_length = text.length();

    if (pattern_length == 0) {
        // An empty pattern is found at every index, including the end of the text.
        result.reserve(text_length + 1);
        for (size_t i = 0; i <= text_length; ++i) {
            result.push_back(i);
        }
        return result;
    }

    if (text_length < pattern_length) {
        // The text is too short to contain the pattern.
        return result;
    }

    const unsigned long long mod = static_cast<unsigned long long>(modulus);
    const unsigned long long radix = static_cast<unsigned long long>(base);

    // Hash of the pattern and of the first window of the text.
    const unsigned long long pattern_hash = compute_hash(pattern, 0, pattern_length);
    unsigned long long text_hash = compute_hash(text, 0, pattern_length);

    // base^(pattern_length - 1) mod modulus: the weight of a window's first character.
    unsigned long long power = 1;
    for (size_t i = 1; i < pattern_length; ++i) {
        power = (power * radix) % mod;
    }

    const size_t last_start = text_length - pattern_length;
    for (size_t i = 0; i <= last_start; ++i) {
        if (pattern_hash == text_hash && check_collision(text, i, pattern, 0, pattern_length)) {
            result.push_back(i);
        }

        if (i < last_start) {
            // Slide the window: drop text[i], shift, append text[i + pattern_length].
            // Bytes are read as unsigned so the update agrees with compute_hash
            // for characters >= 0x80 as well.
            const unsigned long long outgoing = static_cast<unsigned char>(text[i]);
            const unsigned long long incoming = static_cast<unsigned char>(text[i + pattern_length]);
            // Adding mod before subtracting keeps the intermediate non-negative.
            const unsigned long long without_outgoing =
                (text_hash + mod - (outgoing * power) % mod) % mod;
            text_hash = (without_outgoing * radix + incoming) % mod;
        }
    }

    return result;
}
113
114#endif