00001 #ifndef LM_HH
00002 #define LM_HH
00003
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

#include "bit/Trie.hh"
#include "bit/CompressedArray.hh"
#include "bit/FloatArray.hh"
#include "SymbolMap.hh"
00008
00011 namespace bit {
00012
00020 class LM {
00021 public:
00022
00024 typedef Trie<CompressedArray> Trie;
00025
00027 typedef Trie::Iterator Iterator;
00028
00030 typedef SymbolMap<std::string, int> SymbolMap;
00031
00033 LM()
00034 {
00035 reset();
00036 }
00037
00039 void reset()
00040 {
00041 m_symbol_map = SymbolMap();
00042 m_start_symbol = -1;
00043 m_end_symbol = -1;
00044 m_trie = Trie();
00045 m_backoff_arrays.clear();
00046 m_score_arrays.clear();
00047 m_previous_ngram.clear();
00048 }
00049
00051 unsigned int order() const
00052 {
00053 return m_score_arrays.size();
00054 }
00055
00057 u64 size() const
00058 {
00059 u64 size = 0;
00060 assert(m_score_arrays.size() == m_backoff_arrays.size());
00061 for (size_t i = 0; i < m_score_arrays.size(); i++) {
00062 size += m_score_arrays[i].data_len();
00063 size += m_backoff_arrays[i].data_len();
00064 }
00065 return size + m_trie.size();
00066 }
00067
00069 const FloatArray &score_array(unsigned int level) const
00070 {
00071 return m_score_arrays.at(level);
00072 }
00073
00075 const FloatArray &backoff_array(unsigned int level) const
00076 {
00077 return m_backoff_arrays.at(level);
00078 }
00079
00081 const CompressedArray &symbol_array(unsigned int level) const
00082 {
00083 return m_trie.symbol_array(level);
00084 }
00085
00087 const CompressedArray &pointer_array(unsigned int level) const
00088 {
00089 return m_trie.pointer_array(level);
00090 }
00091
00093 const CompressedArray &child_limit_array(unsigned int level) const
00094 {
00095 return m_trie.child_limit_array(level);
00096 }
00097
00103 void
00104 read_arpa(FILE *file, const std::string &sentence_start_str = "<s>",
00105 const std::string &sentence_end_str = "</s>",
00106 bool verbose = false);
00107
00111 void write_arpa(FILE *file) const;
00112
00117 void write(FILE *file) const;
00118
00123 void read(FILE *file);
00124
00133 void linear_quantization(unsigned int bits);
00134
00138 void compress_trie(unsigned int level)
00139 {
00140 m_trie.compress(level);
00141 }
00142
00144 void compress_trie()
00145 {
00146 m_trie.compress();
00147 }
00148
00152 void uncompress_trie(unsigned int level)
00153 {
00154 m_trie.uncompress(level);
00155 }
00156
00158 void uncompress_trie()
00159 {
00160 m_trie.uncompress();
00161 }
00162
00171 void separate_leafs(unsigned int level);
00172
00181 void unseparate_leafs(unsigned int level);
00182
00191 void
00192 insert_ngram(const std::vector<int> &ngram, float score, float backoff);
00193
00202 void
00203 insert_ngram(const std::string &str, float score, float backoff);
00204
00209 void set_start_symbol(const std::string &str)
00210 {
00211 if (m_start_symbol >= 0)
00212 throw bit::invalid_call("bit::LM::set_start_symbol() called again");
00213 m_start_symbol = m_symbol_map.insert(str);
00214 }
00215
00220 void set_end_symbol(const std::string &str)
00221 {
00222 if (m_end_symbol >= 0)
00223 throw bit::invalid_call("bit::LM::set_end_symbol() called again");
00224 m_end_symbol = m_symbol_map.insert(str);
00225 }
00226
00228 int start_symbol() const {
00229 return m_start_symbol;
00230 }
00231
00233 int end_symbol() const {
00234 return m_end_symbol;
00235 }
00236
00238 const SymbolMap &symbol_map() const
00239 {
00240 return m_symbol_map;
00241 }
00242
00247 template <class T>
00248 std::string
00249 ngram_str(const std::vector<T> &vec) const
00250 {
00251 assert(!vec.empty());
00252 std::string str;
00253 for (size_t o = 0; o < vec.size(); o++) {
00254 if (o != 0)
00255 str.append(" ");
00256 str.append(m_symbol_map.at(vec[o]));
00257 }
00258 return str;
00259 }
00260
00262 Iterator root() const
00263 {
00264 return Iterator(m_trie);
00265 }
00266
00270 float backoff(const Iterator &it) const
00271 {
00272 if (it.is_root())
00273 throw invalid_call("lm::LM::backoff() called at root");
00274 u32 index = it.child_limit_index();
00275 if (index == max_u32)
00276 return 0;
00277 return backoff(it.length() - 1, index);
00278 }
00279
00287 float backoff(unsigned int level, u64 index) const
00288 {
00289 if (level >= m_score_arrays.size())
00290 throw bit::invalid_argument("bit::LM::backoff() level too high");
00291 const FloatArray &score_array = m_score_arrays.at(level);
00292 const FloatArray &backoff_array = m_backoff_arrays.at(level);
00293 if (index >= score_array.num_elems())
00294 throw bit::invalid_argument("bit::LM::backoff() index too high");
00295 if (index >= backoff_array.num_elems())
00296 return 0;
00297 return backoff_array.get(index);
00298 }
00299
00301 float score(const Iterator &it) const
00302 {
00303 u64 index = it.symbol_index();
00304 unsigned int level = it.length() - 1;
00305 return m_score_arrays.at(level).get(index);
00306 }
00307
00316 float walk(Iterator &it, int symbol) const
00317 {
00318 float score = 0;
00319 while (!it.goto_child(symbol)) {
00320 assert(it.length() > 0);
00321 score += backoff(it);
00322 it.goto_backoff_full();
00323 }
00324 u64 index = it.symbol_index();
00325 score += m_score_arrays.at(it.length() - 1).get(index);
00326 if (it.num_children() == 0) {
00327 assert(backoff(it) == 0);
00328 it.goto_backoff_full();
00329 }
00330 return score;
00331 }
00332
00333 private:
00334
00338 int compare_ngrams(const std::vector<int> &a, const std::vector<int> &b)
00339 {
00340 std::vector<int>::size_type o = 0;
00341 while (1) {
00342 if (o == a.size() && o == b.size())
00343 return 0;
00344 if (o == a.size())
00345 return -1;
00346 if (o == b.size())
00347 return 1;
00348 if (a[o] < b[o])
00349 return -1;
00350 if (a[o] > b[o])
00351 return 1;
00352 o++;
00353 }
00354 }
00355
00356 private:
00357
00359 SymbolMap m_symbol_map;
00360
00362 int m_start_symbol;
00363
00365 int m_end_symbol;
00366
00368 Trie m_trie;
00369
00371 std::vector<FloatArray> m_backoff_arrays;
00372
00374 std::vector<FloatArray> m_score_arrays;
00375
00377 std::vector<int> m_previous_ngram;
00378 };
00379
00380 };
00381
00382 #endif