00001 #ifndef KNESER_HH
00002 #define KNESER_HH
00003
00004 #include "bit/Array.hh"
00005 #include "bit/Trie.hh"
00006 #include "str/str.hh"
00007 #include "lm/SymbolMap.hh"
00008 #include "util/Progress.hh"
00009 #include "util/util.hh"
00010
00011 namespace bit {
00012
00020 class Kneser {
00021 public:
00022
// Trie indexing all n-grams; each node is one n-gram and Iterator walks it.
// NOTE(review): `typedef Trie<Array> Trie;` (and the SymbolMap typedef below)
// re-declares, inside the class, the very template name it uses. Some
// compilers reject this as "changes meaning of 'Trie'"; qualifying the
// template with its namespace would avoid it -- confirm the namespace first.
typedef Trie<Array> Trie;
typedef Trie::Iterator Iterator;
// Mapping between word strings and integer symbol ids (index() / at()).
typedef SymbolMap<std::string, int> SymbolMap;
typedef std::vector<float> FloatVec;
typedef std::vector<int> IntVec;
// An n-gram as symbol ids in text order: the last element is the predicted
// word, the preceding elements are its context.
typedef std::vector<int> Ngram;
00029
00031 Kneser()
00032 {
00033 m_d1_weight_model = 1;
00034 m_d1_model = 1;
00035
00036 m_sentence_start_str = "<s>";
00037 m_sentence_end_str = "</s>";
00038 m_sentence_start_id = -1;
00039 m_sentence_end_id = -1;
00040 m_progress_skip = 53871;
00041 reserve_orders(10);
00042 }
00043
00045 void set_d1_model(int model)
00046 {
00047 m_d1_model = model;
00048 }
00049
00051 void set_d1_weight_model(int model)
00052 {
00053 m_d1_weight_model = model;
00054 }
00055
00057 const SymbolMap &symbol_map() const
00058 {
00059 return m_symbol_map;
00060 }
00061
00063 int sentence_start_id() const
00064 {
00065 return m_sentence_start_id;
00066 }
00067
00069 int sentence_end_id() const
00070 {
00071 return m_sentence_end_id;
00072 }
00073
00075 Iterator root() const
00076 {
00077 return Iterator(m_trie);
00078 }
00079
00081 u64 num_ngrams() const
00082 {
00083 u64 ret = 0;
00084 for (size_t i = 0; i < m_counts.size(); i++)
00085 ret += m_counts[i].size();
00086 return ret;
00087 }
00088
00090 u64 num_active_ngrams() const
00091 {
00092 u64 ret = 0;
00093 for (size_t i = 0; i < m_num_ngrams.size(); i++)
00094 ret += m_num_ngrams[i];
00095 return ret;
00096 }
00097
00102 template <class T>
00103 Iterator find(const std::vector<T> &vec) const
00104 {
00105 return m_trie.find(vec);
00106 }
00107
00112 Iterator find(const std::string &str) const
00113 {
00114 return m_trie.find(ngram(str));
00115 }
00116
00123 float ngram_prob(Iterator it) const
00124 {
00125 double ret = 1;
00126 if (it.length() == 0)
00127 throw bit::invalid_argument("bit::Kneser::prob_weight() at root");
00128 while (it.length() > 0) {
00129 if (it.length() != 1 || it.symbol() != m_sentence_start_id) {
00130 switch (m_d1_model) {
00131 case 0:
00132 ret *= prob_full(it);
00133 break;
00134 case 1:
00135 ret *= prob_abs_full(it);
00136 break;
00137 case 2:
00138 throw bit::invalid_argument(
00139 "bit::Kneser::ngram_prob() invalid m_d1_model");
00140 }
00141 }
00142 it.goto_parent();
00143 }
00144 return ret;
00145 }
00146
00153 float prob_beta_lower(Iterator it) const
00154 {
00155 assert(it.length() > 0);
00156
00157 double ret = 0;
00158 double scale = 1;
00159 u32 symbol = it.symbol();
00160 it.goto_parent();
00161 scale *= get_beta_interpolation_numerator(it) /
00162 get_beta_denominator(it);
00163
00164 while (!it.is_root()) {
00165 it.goto_backoff_full();
00166 if (it.goto_child(symbol)) {
00167 double numerator = 0;
00168 if (!is_pruned(it)) {
00169 numerator = get_beta_numerator(it) -
00170 get_beta_discount(it.length() - 1);
00171 it.goto_parent();
00172 ret += scale * numerator / get_beta_denominator(it);
00173 }
00174 else
00175 it.goto_parent();
00176 }
00177
00178 scale *= get_beta_interpolation_numerator(it) /
00179 get_beta_denominator(it);
00180 }
00181
00182
00183
00184
00185 ret += scale / (m_symbol_map.size() - 1);
00186
00187 return ret;
00188 }
00189
00191 Ngram ngram(const std::string &str) const
00192 {
00193 std::vector<std::string> symbols = str::split(str, " \t", true);
00194 Ngram ngram(symbols.size());
00195 for (size_t i = 0; i < symbols.size(); i++)
00196 ngram[i] = m_symbol_map.index(symbols[i]);
00197 return ngram;
00198 }
00199
00205 float prob_beta_full(const Iterator &it) const
00206 {
00207 assert(m_sentence_start_id >= 0);
00208 assert(it.length() > 0);
00209 if (it.symbol() == m_sentence_start_id)
00210 throw bit::invalid_argument(
00211 "bit::Kneser::prob_beta_full(): called for sentence start");
00212
00213 unsigned int len = it.length();
00214 double numerator = 0;
00215 if (!is_pruned(it))
00216 numerator = get_beta_numerator(it) -
00217 get_beta_discount(len - 1);
00218 double denominator = get_beta_denominator(it.parent());
00219 return numerator / denominator + prob_beta_lower(it);
00220 }
00221
00230 float prob_beta_full(Ngram ngram) const
00231 {
00232 if (ngram.empty())
00233 throw bit::invalid_argument(
00234 "bit::Kneser::prob_beta_full(Ngram): ngram empty");
00235 int symbol = ngram.back();
00236 if (symbol == m_sentence_start_id)
00237 throw bit::invalid_argument(
00238 "bit::Kneser::prob_beta_full(Ngram) sentence start symbol");
00239
00240 double scale = 1;
00241 double ret = 0;
00242 while (1) {
00243 assert(ngram.size() > 0);
00244 Iterator it = find(ngram);
00245
00246
00247 if (!it.is_root()) {
00248 Iterator parent = it.parent();
00249 double denominator = get_beta_denominator(parent);
00250 if (!is_pruned(it)) {
00251 double numerator = get_beta_numerator(it) -
00252 get_beta_discount(it.length() - 1);
00253 assert(numerator >= 0);
00254 ret += scale * numerator / denominator;
00255 }
00256 scale *= get_beta_interpolation_numerator(parent) / denominator;
00257 }
00258
00259
00260 else {
00261 if (it.length() == 1)
00262 scale *= get_beta_interpolation_numerator(root()) /
00263 get_beta_denominator(root());
00264 else {
00265 Ngram parent_ngram(ngram.begin(), ngram.end() - 1);
00266 Iterator parent_it = find(parent_ngram);
00267 if (!parent_it.is_root())
00268 scale *= get_beta_interpolation_numerator(parent_it) /
00269 get_beta_denominator(parent_it);
00270 }
00271 }
00272
00273 if (it.length() == 1)
00274 break;
00275
00276 ngram.erase(ngram.begin());
00277 }
00278
00279 ret += scale / (m_symbol_map.size() - 1);
00280
00281 return ret;
00282 }
00283
00289 float prob_lower(Iterator it) const
00290 {
00291 assert(it.length() > 0);
00292
00293 double ret = 0;
00294 double scale = 1;
00295 u32 symbol = it.symbol();
00296 it.goto_parent();
00297 scale *= interpolation(it);
00298
00299 while (!it.is_root()) {
00300 it.goto_backoff_full();
00301 if (it.goto_child(symbol)) {
00302 double numerator =
00303 sum_nonzero_xg(it) - get_discount(it.length() - 1);
00304 it.goto_parent();
00305 double denominator = sum_nonzero_xgx(it);
00306 ret += scale * numerator / denominator;
00307 }
00308
00309 scale *= interpolation(it);
00310 }
00311
00312
00313
00314
00315 ret += scale / (m_symbol_map.size() - 1);
00316
00317 return ret;
00318 }
00319
00326 float prob_full(const Iterator &it, float *lower_prob = NULL) const
00327 {
00328 assert(m_sentence_start_id >= 0);
00329 assert(it.length() > 0);
00330 if (it.symbol() == m_sentence_start_id)
00331 throw bit::invalid_argument(
00332 "bit::Kneser::prob_full_ikn(): called for sentence start");
00333
00334 unsigned int len = it.length();
00335 double numerator = sum_nonzero_xg(it) - get_discount(len - 1);
00336 double denominator = sum_nonzero_xgx(it.parent());
00337 double lower = prob_lower(it);
00338 if (lower_prob != NULL)
00339 *lower_prob = lower;
00340 return numerator / denominator + lower;
00341 }
00342
00349 float prob_abs_lower(Iterator it) const
00350 {
00351 assert(it.length() > 0);
00352
00353 double ret = 0;
00354 double scale = 1;
00355 u32 symbol = it.symbol();
00356 it.goto_parent();
00357 scale *= interpolation_abs(it);
00358
00359 while (!it.is_root()) {
00360 it.goto_backoff_full();
00361 if (it.goto_child(symbol)) {
00362 double numerator =
00363 get_count(it) - get_discount(it.length() - 1);
00364 it.goto_parent();
00365 double denominator = sum_gx(it);
00366 ret += scale * numerator / denominator;
00367 }
00368
00369 scale *= interpolation_abs(it);
00370 }
00371
00372
00373
00374
00375 ret += scale / (m_symbol_map.size() - 1);
00376
00377 return ret;
00378 }
00379
00386 float prob_abs_full(const Iterator &it, float *lower_prob = NULL) const
00387 {
00388 assert(m_sentence_start_id >= 0);
00389 assert(it.length() > 0);
00390 if (it.symbol() == m_sentence_start_id)
00391 throw bit::invalid_argument(
00392 "bit::Kneser::prob_abs_full(): called for sentence start");
00393
00394 unsigned int len = it.length();
00395 double numerator = get_count(it) - get_discount(len - 1);
00396 double denominator = sum_gx(it.parent());
00397 double lower = prob_abs_lower(it);
00398 if (lower_prob != NULL)
00399 *lower_prob = lower;
00400 return numerator / denominator + lower;
00401 }
00402
00404 bool is_pruned(const Iterator &it) const
00405 {
00406 return get_value(m_pruned, it) > 0;
00407 }
00408
00415 u32 get_count(const Iterator &it) const
00416 {
00417 return get_value(m_counts, it);
00418 }
00419
00420 u32 sum_gx(const Iterator &it) const
00421 {
00422 if (it.is_root())
00423 return m_sum_gx0;
00424 return get_value(m_sum_gx, it);
00425 }
00426
00427 u32 sum_nonzero_xg(const Iterator &it) const
00428 {
00429 return get_value(m_sum_nonzero_xg, it);
00430 }
00431
00432 u32 sum_nonzero_gx(const Iterator &it) const
00433 {
00434 if (it.is_root())
00435 return m_sum_nonzero_gx0;
00436 return get_value(m_sum_nonzero_gx, it);
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446 }
00447
00448 u32 sum_nonzero_xgx(const Iterator &it) const
00449 {
00450 if (it.is_root())
00451 return m_sum_nonzero_xgx0;
00452 return get_value(m_sum_nonzero_xgx, it);
00453
00454
00455
00456
00457
00458
00459
00460
00461 }
00462
00463 float get_beta_numerator(const Iterator &it) const
00464 {
00465 if (is_pruned(it))
00466 return 0;
00467
00468 int term1 = get_count(it) - get_value(m_sum_xg_not_pruned, it);
00469
00470 float term2;
00471 if (1) {
00472 term2 = get_beta_discount(it.length()) *
00473 get_value(m_sum_nonzero_xg_not_pruned, it);
00474 }
00475 else {
00476 term2 = get_value(m_sum_nonzero_xg_not_pruned, it);
00477 static bool printed = false;
00478 if (!printed) {
00479 fprintf(stderr, "\n"
00480 "**************************************************\n"
00481 "* WARNING: using d = 1 in beta numerator\n"
00482 "*\n");
00483 }
00484 printed = true;
00485 }
00486
00487 if (term1 < 0)
00488 throw std::out_of_range(
00489 str::fmt(256, "N(h,w) - Sum_v N(v,h,w) negative: %d\n", term1));
00490
00491 return term1 + term2;
00492 }
00493
00495 float get_beta_denominator(const Iterator &it) const
00496 {
00497 if (it.is_root())
00498 return m_beta_denominator0;
00499 if (it.length() > m_beta_denominator.size())
00500 return 1;
00501 return get_value(m_beta_denominator, it);
00502 }
00503
00504 float get_beta_interpolation_numerator(const Iterator &it) const
00505 {
00506 if (it.is_root())
00507 return m_beta_interpolation_numerator0;
00508 if (it.length() > m_beta_interpolation_numerator.size())
00509 return 1;
00510 return get_value(m_beta_interpolation_numerator, it);
00511 }
00512
00516 float get_d1(const Iterator &it) const
00517 {
00518 if (it.length() < 2)
00519 throw bit::invalid_argument("bit::Kneser::get_d1(): too low order");
00520 return get_value(m_d1, it);
00521 }
00522
00526 float get_d2(const Iterator &it) const
00527 {
00528 if (it.length() < 2)
00529 throw bit::invalid_argument("bit::Kneser::get_d2(): too low order");
00530 return get_value(m_d2, it) / get_value(m_d2_norm, it);
00531 }
00532
00534 int num_active_children(Iterator it) const
00535 {
00536 int ret = 0;
00537 if (it.goto_first_child()) {
00538 do {
00539 if (!is_pruned(it) && it.symbol() != m_sentence_start_id)
00540 ret++;
00541 } while (it.goto_next_sibling());
00542 }
00543 return ret;
00544 }
00545
00550 template <class T>
00551 std::string
00552 ngram_str(const std::vector<T> &ngram) const
00553 {
00554 std::string str;
00555 for (size_t o = 0; o < ngram.size(); o++) {
00556 if (o != 0)
00557 str.append(" ");
00558 str.append(m_symbol_map.at(ngram[o]));
00559 }
00560 return str;
00561 }
00562
00565 void write_binary_counts(FILE *file) const
00566 {
00567 m_symbol_map.write(file);
00568 m_trie.write(file);
00569 util::write(file, m_counts);
00570 util::write(file, m_sum_nonzero_xg);
00571 util::write(file, m_sum_nonzero_xgx);
00572 util::write(file, m_sum_nonzero_gx);
00573 util::write(file, m_sum_gx);
00574 fwrite(&m_sum_nonzero_xgx0, sizeof(m_sum_nonzero_xgx0), 1, file);
00575 fwrite(&m_sum_nonzero_gx0, sizeof(m_sum_nonzero_gx0), 1, file);
00576 fwrite(&m_sum_gx0, sizeof(m_sum_gx0), 1, file);
00577 }
00578
00581 void read_binary_counts(FILE *file)
00582 {
00583 m_symbol_map.read(file);
00584 m_trie.read(file);
00585 util::read(file, m_counts);
00586 util::read(file, m_sum_nonzero_xg);
00587 util::read(file, m_sum_nonzero_xgx);
00588 util::read(file, m_sum_nonzero_gx);
00589 util::read(file, m_sum_gx);
00590 size_t ret = 0;
00591 ret += fread(&m_sum_nonzero_xgx0, sizeof(m_sum_nonzero_xgx0), 1, file);
00592 ret += fread(&m_sum_nonzero_gx0, sizeof(m_sum_nonzero_gx0), 1, file);
00593 ret += fread(&m_sum_gx0, sizeof(m_sum_gx0), 1, file);
00594 if (ret != 3)
00595 throw io_error("bit::Kneser::read_binary_counts() failed");
00596
00597 m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00598 m_sentence_end_id = m_symbol_map.index(m_sentence_end_str);
00599 m_num_ngrams.resize(m_counts.size());
00600 for (size_t i = 0; i < m_counts.size(); i++)
00601 m_num_ngrams.at(i) = m_counts[i].size();
00602 }
00603
00605 void write_binary_d1d2(FILE *file) const
00606 {
00607 util::write(file, m_d1);
00608 util::write(file, m_d2);
00609 util::write(file, m_d2_norm);
00610 }
00611
00613 void read_binary_d1d2(FILE *file)
00614 {
00615 util::read(file, m_d1);
00616 util::read(file, m_d2);
00617 util::read(file, m_d2_norm);
00618 }
00619
00621 void write_arpa(FILE *file) const
00622 {
00623 fprintf(file, "\\data\\\n");
00624
00625
00626
00627
00628 for (size_t o = 0; o < m_counts.size(); o++) {
00629 fprintf(file, "\n\\%zd-grams:\n", o+1);
00630 Iterator it = root();
00631 while (it.goto_next_on_level(o)) {
00632
00633 float log_prob;
00634 if (it.length() == 1 && it.symbol() == m_sentence_start_id)
00635 log_prob = -99;
00636 else
00637 log_prob = log10(prob_full(it));
00638
00639 fprintf(file, "%g\t%s", log_prob,
00640 ngram_str(it.symbol_vec()).c_str());
00641 if (it.num_children() > 0)
00642 fprintf(file, "\t%g\n", log10(interpolation(it)));
00643 else
00644 fputs("\n", file);
00645 }
00646 }
00647 fprintf(file, "\n\\end\\\n");
00648 }
00649
00651 void write_beta_arpa(FILE *file) const
00652 {
00653 fprintf(file, "\\data\\\n");
00654
00655
00656
00657
00658 for (size_t i = 0; i < m_num_ngrams.size(); i++) {
00659 assert(m_num_ngrams[i] >= 0);
00660 if (m_num_ngrams[i] == 0)
00661 break;
00662 fprintf(file, "ngram %d=%d\n", (int)i + 1, m_num_ngrams[i]);
00663 }
00664
00665 Progress p(m_progress_skip, num_ngrams());
00666 p.set_report_string("writing arpa:");
00667
00668 for (size_t o = 0; o < m_counts.size(); o++) {
00669 fprintf(file, "\n\\%zd-grams:\n", o+1);
00670 Iterator it = root();
00671 while (it.goto_next_on_level(o)) {
00672 p.step();
00673 if (get_value(m_pruned, it) > 0)
00674 continue;
00675
00676 float log_prob;
00677 if (it.length() == 1 && it.symbol() == m_sentence_start_id)
00678 log_prob = -99;
00679 else
00680 log_prob = log10(prob_beta_full(it));
00681
00682 fprintf(file, "%g\t%s", log_prob,
00683 ngram_str(it.symbol_vec()).c_str());
00684 if (num_active_children(it) > 0)
00685 fprintf(file, "\t%g\n",
00686 log10(get_beta_interpolation_numerator(it) /
00687 get_beta_denominator(it)));
00688 else
00689 fputs("\n", file);
00690 }
00691 }
00692 fprintf(file, "\n\\end\\\n");
00693 p.finish();
00694 }
00695
00699 void reserve_orders(unsigned int orders)
00700 {
00701 m_counts.reserve(orders);
00702 m_sum_nonzero_xg.reserve(orders);
00703 m_sum_nonzero_gx.reserve(orders);
00704 m_sum_nonzero_xgx.reserve(orders);
00705 m_pruned.reserve(orders);
00706 m_d1.reserve(orders);
00707 m_d2.reserve(orders);
00708 m_d2_norm.reserve(orders);
00709 m_sum_xg_not_pruned.reserve(orders);
00710 m_sum_nonzero_xg_not_pruned.reserve(orders);
00711 m_beta_denominator.reserve(orders);
00712 m_beta_interpolation_numerator.reserve(orders);
00713 m_trie.reserve_levels(orders);
00714 }
00715
00716
00725 void read_counts(FILE *file, bool integer_symbols = false)
00726 {
00727 std::string line;
00728 std::vector<int> ngram;
00729 std::vector<std::string> fields;
00730 int count;
00731 Progress p(m_progress_skip);
00732 p.set_report_string("reading counts:");
00733 while (str::read_line(line, file, true)) {
00734 try {
00735 if (integer_symbols) {
00736 ngram = str::long_vec<int>(line);
00737 if (ngram.size() < 2)
00738 throw std::exception();
00739 count = ngram.back();
00740 ngram.pop_back();
00741 }
00742 else {
00743 fields = str::split(line, " \t", true);
00744 if (fields.size() < 2)
00745 throw std::exception();
00746 count = str::str2long(fields.back());
00747 fields.pop_back();
00748 ngram.resize(fields.size());
00749 for (size_t i = 0; i < fields.size(); i++)
00750 ngram[i] = m_symbol_map.insert(fields[i]);
00751 }
00752 }
00753 catch (std::exception &e) {
00754 throw bit::io_error(
00755 std::string("bit::Kneser::read_counts(): invalid line: ") + line);
00756 }
00757
00758 add(ngram, count);
00759 p.step();
00760 }
00761 p.finish();
00762
00763 m_num_ngrams.resize(m_counts.size());
00764 for (size_t i = 0; i < m_counts.size(); i++)
00765 m_num_ngrams.at(i) = m_counts[i].size();
00766 }
00767
00773 void compute_sums()
00774 {
00775 assert(m_sentence_start_id < 0);
00776 m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00777 m_sentence_end_id = m_symbol_map.index(m_sentence_end_str);
00778 {
00779 Iterator it = root();
00780 it.goto_child(m_sentence_start_id);
00781 set_count(it, 0);
00782 fprintf(stderr, "WARNING: setting count(%s) = 0\n",
00783 m_sentence_start_str.c_str());
00784 }
00785
00786 assert(m_sum_nonzero_xg.empty());
00787 assert(m_sum_nonzero_xgx.empty());
00788 m_sum_gx.resize(m_counts.size() - 1);
00789 m_sum_nonzero_xg.resize(m_counts.size());
00790 m_sum_nonzero_xgx.resize(m_counts.size() - 1);
00791 m_sum_nonzero_gx.resize(m_counts.size() - 1);
00792 for (size_t i = 0; i < m_sum_nonzero_xg.size(); i++)
00793 m_sum_nonzero_xg[i].resize(m_counts[i].size());
00794 for (size_t i = 0; i < m_sum_nonzero_xgx.size(); i++) {
00795 m_sum_gx[i].resize(m_counts[i].size());
00796 m_sum_nonzero_xgx[i].resize(m_counts[i].size());
00797 m_sum_nonzero_gx[i].resize(m_counts[i].size());
00798 }
00799 m_sum_nonzero_xgx0 = 0;
00800 m_sum_nonzero_gx0 = 0;
00801 m_sum_gx0 = 0;
00802
00803
00804
00805 {
00806 Progress p(m_progress_skip, num_ngrams());
00807 p.set_report_string("modified counts:");
00808 Iterator it = root();
00809 while (it.goto_next_depth_first()) {
00810 p.step();
00811
00812
00813 if (m_d1_model == 1 || m_d1_weight_model == 1) {
00814 Iterator parent(it.parent());
00815 u32 count = get_count(it);
00816 if (parent.length() == 0)
00817 m_sum_gx0 += count;
00818 else
00819 add_value(m_sum_gx, parent, count);
00820 }
00821
00822 if (it.length() < 2)
00823 continue;
00824
00825 Iterator bo_it(it);
00826 if (bo_it.goto_backoff_once())
00827 add_value(m_sum_nonzero_xg, bo_it, 1);
00828 }
00829 p.finish();
00830 }
00831
00832
00833
00834
00835 {
00836 Progress p(m_progress_skip, num_ngrams());
00837 p.set_report_string("without context:");
00838 for (size_t o = 0; o < m_counts.size(); o++) {
00839 IntVec &src_array = m_counts.at(o);
00840 IntVec &sum_nonzero_xg = m_sum_nonzero_xg.at(o);
00841 for (u64 i = 0; i < src_array.size(); i++) {
00842 p.step();
00843 if (sum_nonzero_xg.at(i) == 0)
00844 sum_nonzero_xg[i] = src_array.at(i);
00845 }
00846 }
00847 p.finish();
00848 }
00849
00850
00851
00852 {
00853 Progress p(m_progress_skip, num_ngrams());
00854 p.set_report_string("modified sum counts:");
00855 Iterator it = root();
00856 while (it.goto_next_depth_first()) {
00857 p.step();
00858 if (it.symbol() == m_sentence_start_id)
00859 continue;
00860 u32 value = sum_nonzero_xg(it);
00861 Iterator parent = it.parent();
00862 if (parent.is_root()) {
00863 m_sum_nonzero_xgx0 += value;
00864 m_sum_nonzero_gx0++;
00865 }
00866 else {
00867 add_value(m_sum_nonzero_xgx, parent, value);
00868 add_value(m_sum_nonzero_gx, parent, 1);
00869 }
00870 }
00871 p.finish();
00872 }
00873 }
00874
00876 void compute_d1()
00877 {
00878 float max_d1 = -1e30;
00879 float min_d1 = 1e30;
00880
00881 assert(m_d1.empty());
00882 m_d1.resize(m_counts.size());
00883 for (size_t i = 0; i < m_counts.size(); i++)
00884 m_d1.at(i).resize(m_counts.at(i).size());
00885
00886 Iterator it = root();
00887
00888 Progress p(m_progress_skip, num_ngrams());
00889 p.set_report_string("computing d1:");
00890 while (it.goto_next_depth_first()) {
00891 p.step();
00892 if (it.length() < 2)
00893 continue;
00894
00895 double orig = 0;
00896 float lower_prob = 0;
00897 if (m_d1_model == 0) {
00898 orig = prob_full(it, &lower_prob);
00899 }
00900 else if (m_d1_model == 1) {
00901 orig = prob_abs_full(it, &lower_prob);
00902 }
00903 else
00904 throw bit::invalid_argument(
00905 "bit::Kneser::compute_d1() invalid m_d1_model");
00906
00907 double d1 = ngram_prob(it) * log10(orig / lower_prob);
00908
00909 if (0) {
00910 fprintf(stderr, "d1: %12g %s\n", d1,
00911 ngram_str(it.symbol_vec()).c_str());
00912 }
00913
00914 if (!(d1 > 0 && d1 < 1e10)) {
00915 fprintf(stderr, "WARNING: d1 = %g for %s\n", d1,
00916 ngram_str(it.symbol_vec()).c_str());
00917 }
00918
00919 set_value(m_d1, it, d1);
00920 if (d1 < min_d1)
00921 min_d1 = d1;
00922 if (d1 > max_d1)
00923 max_d1 = d1;
00924 }
00925 p.finish();
00926
00927 fprintf(stderr, "min_d1 = %g\n", min_d1);
00928 fprintf(stderr, "max_d1 = %g\n", max_d1);
00929 }
00930
00933 void compute_d2_full()
00934 {
00935 assert(m_d2.empty());
00936 assert(!m_d1.empty());
00937 assert(m_d2_norm.empty());
00938 m_d2.resize(m_counts.size());
00939 m_d2_norm.resize(m_counts.size());
00940 for (size_t i = 0; i < m_counts.size(); i++) {
00941 m_d2.at(i).resize(m_counts.at(i).size());
00942 m_d2_norm.at(i).resize(m_counts.at(i).size());
00943 }
00944
00945
00946
00947 Iterator it = root();
00948 Progress p(m_progress_skip, num_ngrams());
00949 p.set_report_string("full d2:");
00950 while (it.goto_next_depth_first_post()) {
00951 p.step();
00952 if (it.length() < 2)
00953 continue;
00954
00955 D2Norm pair(get_d1(it), 1);
00956 Iterator child_it(it);
00957 if (child_it.goto_first_child()) {
00958 do {
00959 float d2 = get_value(m_d2, child_it);
00960 u32 norm = get_value(m_d2_norm, child_it);
00961 pair.add(d2, norm);
00962 } while (child_it.goto_next_sibling());
00963 }
00964
00965 set_value(m_d2, it, pair.d2);
00966 set_value(m_d2_norm, it, pair.norm);
00967
00968 if (0) {
00969 fprintf(stderr, "d2: %12g %s\n", pair.d2 / pair.norm,
00970 ngram_str(it.symbol_vec()).c_str());
00971 }
00972 }
00973 p.finish();
00974 }
00975
00980 void compute_d2_trick()
00981 {
00982 fprintf(stderr, "WARNING: using erroneous d2 measure\n");
00983
00984 assert(m_d2.empty());
00985 assert(!m_d1.empty());
00986 assert(m_d2_norm.empty());
00987 m_d2.resize(m_counts.size());
00988 m_d2_norm.resize(m_counts.size());
00989 for (size_t i = 0; i < m_counts.size(); i++)
00990 m_d2.at(i).resize(m_counts.at(i).size());
00991
00992
00993
00994 std::vector<D2Norm> child_pairs;
00995 Iterator it = root();
00996 Progress p(m_progress_skip, num_ngrams());
00997 p.set_report_string("computing d2:");
00998 while (it.goto_next_depth_first_post()) {
00999 p.step();
01000 if (it.length() < 2)
01001 continue;
01002
01003 D2Norm pair(get_d1(it), 1);
01004 Iterator child_it(it);
01005 child_pairs.clear();
01006 if (child_it.goto_first_child()) {
01007 do {
01008 float d2 = get_value(m_d2, child_it);
01009 u32 norm = get_value(m_d2_norm, child_it);
01010 pair.add(d2, norm);
01011 child_pairs.push_back(D2Norm(d2, norm));
01012 } while (child_it.goto_next_sibling());
01013 std::sort(child_pairs.begin(), child_pairs.end());
01014 }
01015
01016
01017
01018
01019 for (size_t i = 0; i < child_pairs.size(); i++) {
01020 if (pair < child_pairs[i])
01021 break;
01022 pair.add(-child_pairs[i].d2, -child_pairs[i].norm);
01023 }
01024
01025 set_value(m_d2, it, pair.d2);
01026 set_value(m_d2_norm, it, pair.norm);
01027 }
01028 p.finish();
01029 }
01030
01037 void prune_ngram(Iterator it)
01038 {
01039 if (it.length() < 2)
01040 throw bit::invalid_argument(
01041 "bit::Kneser::prune_ngram() ngram shorter than 2-gram");
01042
01043 if (get_value(m_pruned, it) > 0)
01044 throw bit::invalid_argument(
01045 "bit::Kneser::prune_ngram() ngram pruned already");
01046 set_value(m_pruned, it, 1);
01047 m_num_ngrams.at(it.length() - 1)--;
01048
01049 float d2 = get_value(m_d2, it);
01050 int d2_norm = get_value(m_d2_norm, it);
01051
01052 if (0) {
01053 fprintf(stderr, "pruned %12g %12g / %d %s\n",
01054 get_value(m_d1, it),
01055 d2, d2_norm,
01056 ngram_str(it.symbol_vec()).c_str());
01057 }
01058
01059
01060 {
01061 Iterator child_it(it);
01062 unsigned int len = it.length();
01063 while (child_it.goto_next_depth_first()) {
01064 if (child_it.length() <= len)
01065 break;
01066 if (get_value(m_pruned, child_it) > 0)
01067 continue;
01068 if (0) {
01069 fprintf(stderr, "pruned %12g %12g / %d %s\n",
01070 get_value(m_d1, child_it),
01071 get_value(m_d2, child_it),
01072 get_value(m_d2_norm, child_it),
01073 ngram_str(child_it.symbol_vec()).c_str());
01074 }
01075 set_value(m_pruned, child_it, 1);
01076 m_num_ngrams.at(child_it.length() - 1)--;
01077 }
01078 }
01079
01080
01081 while (1) {
01082 it.goto_parent();
01083 if (it.length() == 1)
01084 break;
01085 add_value(m_d2, it, -d2);
01086 add_value(m_d2_norm, it, -d2_norm);
01087 }
01088 }
01089
01093 void prune_threshold(float threshold)
01094 {
01095 assert(!m_d1.empty());
01096 assert(!m_d2.empty());
01097 assert(m_pruned.empty());
01098 m_pruned.resize(m_counts.size());
01099
01100 Progress p(m_progress_skip, num_ngrams());
01101 p.set_report_string("pruning:");
01102 Iterator it = root();
01103 while (it.goto_next_depth_first_post()) {
01104 p.step();
01105 if (it.length() < 2)
01106 continue;
01107 if (get_d2(it) < threshold)
01108 prune_ngram(it);
01109 }
01110 p.finish();
01111 fprintf(stderr, "%lld ngrams left\n", num_active_ngrams());
01112 }
01113
01122 void prune(unsigned int ngrams)
01123 {
01124 assert(!m_d1.empty());
01125 assert(!m_d2.empty());
01126 assert(m_pruned.empty());
01127 m_pruned.resize(m_counts.size());
01128
01129 size_t num_ngrams = 0;
01130 for (size_t i = 0; i < m_counts.size(); i++)
01131 num_ngrams += m_counts.at(i).size();
01132 std::vector<OrderIndex> vec;
01133 vec.reserve(num_ngrams);
01134
01135 Iterator it = root();
01136 while (it.goto_next_depth_first_post()) {
01137 assert(it.length() > 0);
01138 if (it.length() < 2)
01139 continue;
01140 vec.push_back(OrderIndex(it.length() - 1, it.symbol_index()));
01141 }
01142
01143 if (ngrams > vec.size())
01144 throw bit::invalid_argument(
01145 "bit::Kneser::prune() trying to prune too many ngrams");
01146
01147 Progress p(0, 2);
01148 p.set_report_string("sorting:");
01149 p.step();
01150 std::partial_sort(vec.begin(), vec.begin() + ngrams, vec.end(),
01151 PruneCompare(this));
01152 p.step();
01153 p.finish();
01154
01155 for (size_t i = 0; i < ngrams; i++) {
01156 m_num_ngrams.at(vec[i].order)--;
01157 m_pruned.at(vec[i].order).set_grow_widen(vec[i].index, 1);
01158 }
01159
01160
01161
01162
01163
01164
01165
01166
01167
01168
01169 it = root();
01170 while (it.goto_next_depth_first()) {
01171 if (it.length() < 2)
01172 continue;
01173 if (get_value(m_pruned, it.parent()) > 0 &&
01174 get_value(m_pruned, it) == 0)
01175 {
01176 unsigned int order = it.length() - 1;
01177 m_num_ngrams.at(order)--;
01178 set_value(m_pruned, it, 1);
01179
01180
01181
01182 }
01183 }
01184
01185 }
01186
01190 void compute_beta_numerator_terms()
01191 {
01192 assert(!m_discounts.empty());
01193 m_beta_discounts = m_discounts;
01194 for (size_t i = m_beta_discounts.size() - 1; i > 0; i--)
01195 m_beta_discounts[i - 1] *= m_beta_discounts[i];
01196
01197 assert(m_sentence_start_id >= 0);
01198 assert(m_sum_xg_not_pruned.empty());
01199 assert(m_sum_nonzero_xg_not_pruned.empty());
01200 m_sum_xg_not_pruned.resize(m_counts.size() - 1);
01201 m_sum_nonzero_xg_not_pruned.resize(m_counts.size() - 1);
01202 for (size_t i = 0; i < m_counts.size() - 1; i++) {
01203 m_sum_xg_not_pruned.at(i).resize(m_counts.at(i).size());
01204 m_sum_nonzero_xg_not_pruned.at(i).resize(m_counts.at(i).size());
01205 }
01206
01207 Iterator it = root();
01208 Progress p(m_progress_skip, num_ngrams());
01209 p.set_report_string("beta numerator terms:");
01210 while (it.goto_next_depth_first()) {
01211 p.step();
01212 if (get_value(m_pruned, it) > 0)
01213 continue;
01214 if (it.length() < 2)
01215 continue;
01216
01217 Iterator bo_it(it);
01218 if (bo_it.goto_backoff_once()) {
01219 add_value(m_sum_xg_not_pruned, bo_it, get_count(it));
01220 add_value(m_sum_nonzero_xg_not_pruned, bo_it, 1);
01221 }
01222 }
01223 p.finish();
01224 }
01225
01229 void compute_beta_denominator()
01230 {
01231 assert(m_beta_denominator.empty());
01232 m_beta_denominator.resize(m_counts.size() - 1);
01233 m_beta_denominator0 = 0;
01234 for (size_t i = 0; i < m_beta_denominator.size(); i++)
01235 m_beta_denominator.at(i).resize(m_counts.at(i).size());
01236
01237 Iterator it = root();
01238 Progress p(m_progress_skip, num_ngrams());
01239 p.set_report_string("beta denominator:");
01240 while (it.goto_next_depth_first()) {
01241 p.step();
01242
01243 if (it.length() <= m_beta_denominator.size()) {
01244 int pruned_counts;
01245 compute_active_children(it, &pruned_counts);
01246 set_value(m_beta_denominator, it, pruned_counts);
01247 }
01248
01249 if (it.symbol() == m_sentence_start_id)
01250 continue;
01251
01252 Iterator parent_it(it);
01253 parent_it.goto_parent();
01254
01255
01256
01257
01258 float numerator = get_beta_numerator(it);
01259 if (parent_it.is_root())
01260 m_beta_denominator0 += numerator;
01261 else
01262 add_value(m_beta_denominator, parent_it, numerator);
01263 }
01264 p.finish();
01265 }
01266
01268 int compute_active_children(Iterator it, int *pruned_counts = NULL)
01269 {
01270 int num_active_children = 0;
01271 if (pruned_counts != NULL)
01272 (*pruned_counts) = 0;
01273 if (it.goto_first_child()) {
01274 do {
01275 if (it.symbol() == m_sentence_start_id)
01276 continue;
01277 if (is_pruned(it)) {
01278 if (pruned_counts != NULL)
01279 (*pruned_counts) += get_count(it);
01280 continue;
01281 }
01282
01283 num_active_children++;
01284 } while (it.goto_next_sibling());
01285 }
01286 return num_active_children;
01287 }
01288
01293 void compute_beta_interpolation_numerator()
01294 {
01295 assert(m_beta_interpolation_numerator.empty());
01296 m_beta_interpolation_numerator.resize(m_counts.size() - 1);
01297 for (size_t i = 0; i < m_beta_interpolation_numerator.size(); i++)
01298 m_beta_interpolation_numerator.at(i).resize(m_counts.at(i).size());
01299
01300 m_beta_interpolation_numerator0 = 0;
01301 Progress p(m_progress_skip, num_ngrams());
01302 p.set_report_string("beta interpolation numerator:");
01303 Iterator it = root();
01304 do {
01305 p.step();
01306
01307 if (it.length() > m_beta_interpolation_numerator.size())
01308 continue;
01309
01310 if (!it.is_root()) {
01311 if (is_pruned(it)) {
01312 set_value(m_beta_interpolation_numerator, it, 1);
01313 continue;
01314 }
01315 }
01316
01317 int pruned_counts = 0;
01318 int num_active_children = compute_active_children(it, &pruned_counts);
01319 double interpolation = get_beta_discount(it.length()) *
01320 num_active_children + pruned_counts;
01321
01322 if (it.is_root())
01323 m_beta_interpolation_numerator0 = interpolation;
01324 else
01325 set_value(m_beta_interpolation_numerator, it, interpolation);
01326 } while (it.goto_next_depth_first());
01327 p.finish();
01328 }
01329
01336 Iterator add(const std::vector<int> &vec, int value)
01337 {
01338 Iterator it = m_trie.insert(vec);
01339 add_value(m_counts, it, value);
01340 return it;
01341 }
01342
01350 void set_discount(unsigned int order, float value)
01351 {
01352 m_discounts.resize(order + 1);
01353 m_discounts.at(order) = value;
01354 }
01355
01361 float get_discount(unsigned int order) const
01362 {
01363 if (m_discounts.empty())
01364 throw bit::invalid_call(
01365 "bit::Kneser::get_discount() discount not set");
01366 if (order >= m_discounts.size())
01367 return m_discounts.back();
01368 return m_discounts.at(order);
01369 }
01370
01376 float get_beta_discount(unsigned int order) const
01377 {
01378 if (m_beta_discounts.empty())
01379 throw bit::invalid_call(
01380 "bit::Kneser::get_beta_discount() discount not set");
01381 if (order >= m_beta_discounts.size())
01382 return m_beta_discounts.back();
01383 return m_beta_discounts.at(order);
01384 }
01385
01386 float interpolation(const Iterator &it) const
01387 {
01388 assert(it.length() == 0 || it.symbol() != m_sentence_end_id);
01389 double nominator = sum_nonzero_gx(it) * get_discount(it.length());
01390 double denominator = sum_nonzero_xgx(it);
01391 return nominator / denominator;
01392 }
01393
01394 float interpolation_abs(const Iterator &it) const
01395 {
01396 assert(it.length() == 0 || it.symbol() != m_sentence_end_id);
01397 double nominator = sum_nonzero_gx(it) * get_discount(it.length());
01398 double denominator = sum_gx(it);
01399 return nominator / denominator;
01400 }
01401
01406 void set_count(const Iterator &it, u32 value)
01407 {
01408 set_value(m_counts, it, value);
01409 }
01410
01411 std::string debug_sum_nonzero_xg_str()
01412 {
01413 std::string str;
01414 Iterator it = root();
01415 while (it.goto_next_depth_first()) {
01416 str.append(ngram_str(it.symbol_vec()));
01417 str.append("\t");
01418 str.append(str::fmt(64, "%d\n", sum_nonzero_xg(it)));
01419 }
01420 return str;
01421 }
01422
01423 void debug_write_counts(FILE *file)
01424 {
01425 fprintf(file, "sum_nonzero_xg_not_pruned:\n");
01426 Iterator it = root();
01427 while (it.goto_next_depth_first()) {
01428 fprintf(file, "%s\t bdenom=%.2f bnum=%.2f\n",
01429 ngram_str(it.symbol_vec()).c_str(),
01430 it.length() < 3 ? get_beta_denominator(it) : -1,
01431 get_beta_numerator(it)
01432 );
01433
01434 }
01435 }
01436
01437
01438 private:
01439
01440 float get_value(const std::vector<FloatVec> &arrays, const Iterator &it) const
01441 {
01442 unsigned int len = it.length();
01443 u32 index = it.symbol_index();
01444 return arrays.at(len - 1).at(index);
01445 }
01446
01447 int get_value(const std::vector<IntVec> &arrays, const Iterator &it) const
01448 {
01449 unsigned int len = it.length();
01450 if (len > arrays.size())
01451 return 0;
01452 u32 index = it.symbol_index();
01453 const IntVec &vec = arrays.at(len - 1);
01454 if (index >= vec.size())
01455 return 0;
01456 return vec[index];
01457 }
01458
01459 u32 get_value(const std::vector<Array> &arrays, const Iterator &it) const
01460 {
01461 unsigned int len = it.length();
01462 u32 index = it.symbol_index();
01463 if (len > arrays.size())
01464 return 0;
01465 const Array &array = arrays.at(len - 1);
01466 if (index >= array.num_elems())
01467 return 0;
01468 return array.get(index);
01469 }
01470
01471 void set_value(std::vector<FloatVec> &arrays, const Iterator &it,
01472 float value)
01473 {
01474 unsigned int len = it.length();
01475 u32 index = it.symbol_index();
01476 arrays.at(len - 1).at(index) = value;
01477 }
01478
01479 void set_value(std::vector<IntVec> &arrays, const Iterator &it,
01480 int value)
01481 {
01482 unsigned int len = it.length();
01483 if (len > arrays.size())
01484 arrays.resize(len);
01485 u32 index = it.symbol_index();
01486 IntVec &vec = arrays.at(len - 1);
01487 while (vec.size() <= index)
01488 vec.push_back(0);
01489 vec[index] = value;
01490 }
01491
01492 void set_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01493 {
01494 unsigned int len = it.length();
01495 u32 index = it.symbol_index();
01496 if (len > arrays.size())
01497 arrays.resize(len);
01498 arrays.at(len - 1).set_grow_widen(index, value);
01499 }
01500
01501 void add_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01502 {
01503 set_value(arrays, it, get_value(arrays, it) + value);
01504 }
01505
01506 void add_value(std::vector<FloatVec> &arrays, const Iterator &it,
01507 float value)
01508 {
01509 set_value(arrays, it, get_value(arrays, it) + value);
01510 }
01511
01512 void add_value(std::vector<IntVec> &arrays, const Iterator &it,
01513 int value)
01514 {
01515 set_value(arrays, it, get_value(arrays, it) + value);
01516 }
01517
01518 void sub_value(std::vector<Array> &arrays, const Iterator &it, u32 value)
01519 {
01520 u32 old_value = get_value(arrays, it);
01521 if (old_value < value)
01522 throw bit::invalid_argument("bit::Kneser::sub_value() underflow");
01523 set_value(arrays, it, get_value(arrays, it) - value);
01524 }
01525
01528
/// Model selector written by set_d1_weight_model(); the constructor sets 1.
/// NOTE(review): meaning of the values is not visible in this part of the
/// file — document it next to the code that reads this member.
01532 int m_d1_weight_model;
01533
/// Model selector written by set_d1_model(); the constructor sets 1.
/// NOTE(review): meaning of the values is not visible here.
01536 int m_d1_model;
01537
01539
/// Step interval handed to Progress reporters (e.g. in
/// compute_beta_interpolation_numerator()); the constructor sets 53871.
01541 int m_progress_skip;
01542
/// Trie of all n-grams: add() inserts into it; root() and find() return
/// iterators into it.
01544 Trie m_trie;
01545
/// Per-order values summed by num_active_ngrams(); presumably the number
/// of non-pruned n-grams of each order — TODO confirm.
01547 std::vector<int> m_num_ngrams;
01548
/// N-gram counts, one table per order, addressed as
/// m_counts[length - 1][symbol_index]. Written by add() and set_count();
/// num_ngrams() sums the table sizes.
01550 std::vector<IntVec> m_counts;
01551
/// Per-order "sum_gx" statistic; presumably indexed like m_counts —
/// confirm against the code that fills it.
01553 std::vector<IntVec> m_sum_gx;
01554
/// Presumably the root (empty-context) value of m_sum_gx — confirm.
01556 int m_sum_gx0;
01557
/// Per-order "sum_nonzero_xg" statistic; presumably indexed like m_counts.
01559 std::vector<IntVec> m_sum_nonzero_xg;
01560
/// Per-order "sum_nonzero_gx" statistic; presumably indexed like m_counts.
01562 std::vector<IntVec> m_sum_nonzero_gx;
01563
/// Presumably the root value of m_sum_nonzero_gx — confirm.
01565 int m_sum_nonzero_gx0;
01566
/// Per-order "sum_nonzero_xgx" statistic; presumably indexed like m_counts.
01568 std::vector<IntVec> m_sum_nonzero_xgx;
01569
/// Presumably the root value of m_sum_nonzero_xgx — confirm.
01571 int m_sum_nonzero_xgx0;
01572
/// Discount per order: written by set_discount(), read by get_discount(),
/// which uses the last entry for orders beyond the table.
01574 std::vector<float> m_discounts;
01575
/// Beta discount per order, read by get_beta_discount() (last entry used
/// for orders beyond the table).
01579 std::vector<float> m_beta_discounts;
01580
/// Symbol <-> id mapping, exposed through symbol_map().
01582 SymbolMap m_symbol_map;
01583
/// Sentence-start symbol string; the constructor sets "<s>".
01585 std::string m_sentence_start_str;
01586
/// Sentence-end symbol string; the constructor sets "</s>".
01588 std::string m_sentence_end_str;
01589
/// Id of the sentence-start symbol (-1 until assigned); children with
/// this symbol are skipped by compute_active_children().
01591 int m_sentence_start_id;
01592
/// Id of the sentence-end symbol (-1 until assigned); interpolation()
/// asserts that a non-root context does not end in it.
01594 int m_sentence_end_id;
01595
01598
/// Per-order pruning data; presumably the flags consulted by is_pruned()
/// — TODO confirm (is_pruned() is not visible here).
01599 std::vector<Array> m_pruned;
/// Per-order float statistic "d1"; meaning not visible here.
01600 std::vector<FloatVec> m_d1;
/// Per-order "d2" values; PruneCompare orders entries by m_d2 / m_d2_norm.
01601 std::vector<FloatVec> m_d2;
/// Per-order normalizers for m_d2 (see PruneCompare).
01602 std::vector<IntVec> m_d2_norm;
/// Per-order "sum_xg" over non-pruned n-grams (inferred from the name).
01603 std::vector<IntVec> m_sum_xg_not_pruned;
/// Per-order "sum_nonzero_xg" over non-pruned n-grams (inferred from the
/// name).
01604 std::vector<IntVec> m_sum_nonzero_xg_not_pruned;
01605
/// Per-order denominators of the interpolation weight beta; presumably
/// what get_beta_denominator() reads — confirm.
01608 std::vector<FloatVec> m_beta_denominator;
01609
/// Presumably the root value of the beta denominator — confirm.
01612 float m_beta_denominator0;
01613
/// Filled by compute_beta_interpolation_numerator(): one table per order
/// 1..N-1 (N = m_counts.size()), addressed [length - 1][symbol_index].
01616 std::vector<FloatVec> m_beta_interpolation_numerator;
01617
/// Root value computed by compute_beta_interpolation_numerator().
01619 float m_beta_interpolation_numerator0;
01620
01621 struct OrderIndex {
01622 OrderIndex(unsigned int order = 0, u32 index = 0)
01623 : order(order), index(index) { }
01624 unsigned int order;
01625 u32 index;
01626 };
01627
01628 struct PruneCompare {
01629 const Kneser *k;
01630
01631 PruneCompare(const Kneser *k)
01632 : k(k)
01633 {
01634 }
01635
01636 bool operator()(const OrderIndex &a, const OrderIndex &b) const
01637 {
01638 float d2_a = k->m_d2.at(a.order).at(a.index);
01639 int d2_a_norm = k->m_d2_norm.at(a.order).at(a.index);
01640 float d2_b = k->m_d2.at(b.order).at(b.index);
01641 int d2_b_norm = k->m_d2_norm.at(b.order).at(b.index);
01642 return (d2_a / d2_a_norm) < (d2_b / d2_b_norm);
01643 }
01644 };
01645
/// Accumulates a d2 sum together with its normalizer and caches the ratio
/// d2 / norm in `value`; instances compare by that cached ratio.
struct D2Norm {
  float d2;    ///< accumulated d2
  int norm;    ///< accumulated normalizer
  float value; ///< cached d2 / norm (0 for a default-constructed object)

  D2Norm() : d2(0), norm(0), value(0)
  {
  }

  D2Norm(float d2_init, int norm_init)
    : d2(d2_init), norm(norm_init), value(d2_init / norm_init)
  {
  }

  /// Add another (d2, norm) contribution and refresh the cached ratio.
  void add(float d2_delta, int norm_delta)
  {
    d2 += d2_delta;
    norm += norm_delta;
    refresh();
  }

  bool operator<(const D2Norm &other) const
  {
    return value < other.value;
  }

private:
  void refresh()
  {
    value = d2 / norm;
  }
};
01671
01673 };
01674
01675 }
01676
01677 #endif