00001 #ifndef SIMPLEKNESER_HH
00002 #define SIMPLEKNESER_HH
00003
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>

#include "lm/SymbolMap.hh"
#include "str/str.hh"
#include "util/Progress.hh"
#include "bit/exceptions.hh"
00010
00011 namespace bit {
00012
00013 class SimpleKneser {
00014 public:
00015
00016 typedef SymbolMap<std::string,int> SymbolMap;
00017 typedef std::vector<int> Ngram;
00018 typedef std::vector<std::string> StrNgram;
00019 typedef std::vector<int> IntVec;
00020 typedef std::vector<float> FloatVec;
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 struct FloatInt {
00036 FloatInt() : f(0), i(0) { }
00037 FloatInt(float f, int i) : f(f), i(i) { }
00038 float f;
00039 int i;
00040
00041 void add(const FloatInt &value)
00042 {
00043 f += value.f;
00044 i += value.i;
00045 assert(i > 0);
00046 }
00047
00048 void sub(const FloatInt &value)
00049 {
00050 f -= value.f;
00051 i -= value.i;
00052 assert(i > 0);
00053 }
00054
00055 float value() const
00056 {
00057 return f / i;
00058 }
00059
00060 };
00061
00062
00063 typedef std::map<Ngram, int> IntMap;
00064 typedef std::map<Ngram, float> FloatMap;
00065 typedef std::map<Ngram, FloatInt> FloatIntMap;
00066
00067 SimpleKneser()
00068 {
00069 init();
00070 }
00071
00072 SimpleKneser(const std::string &str)
00073 {
00074 init();
00075 StrNgram str_ngram = str::split(str, " \t", true);
00076 for (size_t i = 0; i < str_ngram.size(); i++)
00077 m_symbol_map.insert(str_ngram[i]);
00078 m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00079 m_num_events = m_symbol_map.size() - 1;
00080 }
00081
00082 void init()
00083 {
00084 m_progress_skip = 7351;
00085 m_sentence_start_str = "<s>";
00086 m_sentence_start_id = -1;
00087 }
00088
00089 void set_discounts(const std::string &str)
00090 {
00091 m_discounts = str::float_vec(str);
00092 if (m_discounts.empty())
00093 throw bit::invalid_argument(
00094 "bit::SimpleKneser::set_discounts() empty vector");
00095
00096 m_discounts.insert(m_discounts.begin(), -1e30);
00097
00098 m_beta_discounts = m_discounts;
00099 for (size_t i = m_beta_discounts.size() - 1; i > 0; i--)
00100 m_beta_discounts[i-1] *= m_beta_discounts[i];
00101 }
00102
00103 float get_discount(unsigned int order) const
00104 {
00105 if (m_discounts.empty())
00106 throw bit::invalid_call(
00107 "bit::SimpleKneser::get_discount() discounts not set");
00108 if (order == 0)
00109 throw bit::invalid_argument(
00110 "bit::SimpleKneser::get_discount() trying to get zero order");
00111 if (order >= m_discounts.size())
00112 return m_discounts.back();
00113 return m_discounts.at(order);
00114 }
00115
00116 float get_beta_discount(unsigned int order) const
00117 {
00118 if (m_beta_discounts.empty())
00119 throw bit::invalid_call(
00120 "bit::SimpleKneser::get_beta_discount() beta_discounts not set");
00121 if (order == 0)
00122 throw bit::invalid_argument(
00123 "bit::SimpleKneser::get_beta_discount() trying to get zero order");
00124 if (order >= m_beta_discounts.size())
00125 return m_beta_discounts.back();
00126 return m_beta_discounts.at(order);
00127 }
00128
00129 Ngram ngram(const std::string &str) const
00130 {
00131 StrNgram str_ngram = str::split(str, " \t", true);
00132 Ngram ngram(str_ngram.size());
00133 for (size_t i = 0; i < str_ngram.size(); i++)
00134 ngram[i] = m_symbol_map.index(str_ngram[i]);
00135 return ngram;
00136 }
00137
00138 int get_count(const Ngram &ngram) const
00139 {
00140 IntMap::const_iterator it = m_counts.find(ngram);
00141 if (it == m_counts.end())
00142 throw bit::invalid_argument(
00143 "bit::SimpleKneser::get_count() ngram not found");
00144 return it->second;
00145 }
00146
00147 int get_sum_nonzero_xg(const Ngram &ngram) const
00148 {
00149 IntMap::const_iterator it = m_sum_nonzero_xg.find(ngram);
00150 if (it == m_sum_nonzero_xg.end())
00151 return -1;
00152 return it->second;
00153 }
00154
00155 int get_sum_nonzero_xgx(const Ngram &ngram) const
00156 {
00157 IntMap::const_iterator it = m_sum_nonzero_xgx.find(ngram);
00158 if (it == m_sum_nonzero_xgx.end())
00159 throw bit::invalid_argument(
00160 "bit::SimpleKneser::get_sum_nonzero_xgx() ngram not found");
00161 return it->second;
00162 }
00163
00164 int get_sum_nonzero_gx(const Ngram &ngram) const
00165 {
00166 IntMap::const_iterator it = m_sum_nonzero_gx.find(ngram);
00167 if (it == m_sum_nonzero_gx.end())
00168 throw bit::invalid_argument(
00169 "bit::SimpleKneser::get_sum_nonzero_gx() ngram not found");
00170 assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00171 return it->second;
00172 }
00173
00174 int get_sum_xg_not_pruned(const Ngram &ngram) const
00175 {
00176 if (!ngram.empty() && ngram.back() == m_sentence_start_id)
00177 throw bit::invalid_argument(
00178 "bit::SimpleKneser::get_sum_xg_not_pruned() sentence start");
00179 IntMap::const_iterator it = m_sum_xg_not_pruned.find(ngram);
00180 if (it == m_sum_xg_not_pruned.end())
00181 return 0;
00182 return it->second;
00183 }
00184
00185 int get_sum_nonzero_xg_not_pruned(const Ngram &ngram) const
00186 {
00187 if (!ngram.empty() && ngram.back() == m_sentence_start_id)
00188 throw bit::invalid_argument(
00189 "bit::SimpleKneser::get_sum_nonzero_xg_not_pruned() sentence start");
00190 IntMap::const_iterator it = m_sum_nonzero_xg_not_pruned.find(ngram);
00191 if (it == m_sum_nonzero_xg_not_pruned.end())
00192 return 0;
00193 return it->second;
00194 }
00195
00196 int get_sum_nonzero_gx_not_pruned(const Ngram &ngram) const
00197 {
00198 IntMap::const_iterator it = m_sum_nonzero_gx_not_pruned.find(ngram);
00199 if (it == m_sum_nonzero_gx_not_pruned.end())
00200 return 0;
00201 return it->second;
00202 }
00203
00204 float get_d1(const Ngram &ngram) const
00205 {
00206 FloatMap::const_iterator it = m_d1.find(ngram);
00207 if (it == m_d1.end())
00208 throw bit::invalid_argument(
00209 "bit::SimpleKneser::get_d1() ngram not found");
00210 assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00211 return it->second;
00212 }
00213
00214 float get_d2(const Ngram &ngram) const
00215 {
00216 FloatInt value = get_d2_pair(ngram);
00217 return value.f / value.i;
00218 }
00219
00220 const FloatInt &get_d2_pair(const Ngram &ngram) const
00221 {
00222 FloatIntMap::const_iterator it = m_d2.find(ngram);
00223 if (it == m_d2.end())
00224 throw bit::invalid_argument(
00225 "bit::SimpleKneser::get_d2_pair() ngram not found");
00226 assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00227 return it->second;
00228 }
00229
00230 FloatInt &get_d2_pair(const Ngram &ngram)
00231 {
00232 FloatIntMap::iterator it = m_d2.find(ngram);
00233 if (it == m_d2.end())
00234 throw bit::invalid_argument(
00235 "bit::SimpleKneser::get_d2_pair() ngram not found");
00236 assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00237 return it->second;
00238 }
00239
00240 float get_beta_numerator(const Ngram &ngram) const
00241 {
00242 if (is_pruned(ngram))
00243 return 0;
00244 return get_count(ngram) - get_sum_xg_not_pruned(ngram) +
00245 get_beta_discount(ngram.size() + 1) *
00246 get_sum_nonzero_xg_not_pruned(ngram);
00247 }
00248
00249 float get_beta_denominator(const Ngram &ngram) const
00250 {
00251 FloatMap::const_iterator it = m_beta_denominator.find(ngram);
00252 if (it == m_beta_denominator.end())
00253 throw bit::invalid_argument(
00254 "bit::SimpleKneser::get_beta_denominator() ngram not found");
00255 return it->second;
00256 }
00257
00258 void add_count(const Ngram &ngram, int count)
00259 {
00260 m_counts[ngram] += count;
00261 }
00262
00263 void read_counts(FILE *file)
00264 {
00265 std::string line;
00266 StrNgram str_ngram;
00267 Progress p(m_progress_skip);
00268 p.set_report_string("reading counts:");
00269 while (str::read_line(line, file, true)) {
00270 str_ngram = str::split(line, " \t", true);
00271 if (str_ngram.size() < 2)
00272 throw bit::io_error(
00273 "bit::SimpleKneser::read_counts(): invalid line");
00274 int count = str::str2long(str_ngram.back());
00275 str_ngram.pop_back();
00276
00277 Ngram ngram(str_ngram.size());
00278 for (size_t i = 0; i < str_ngram.size(); i++)
00279 ngram[i] = m_symbol_map.insert(str_ngram[i]);
00280
00281 m_counts[ngram] += count;
00282 p.step();
00283 }
00284 p.finish();
00285
00286 assert(m_sentence_start_id < 0);
00287 m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00288 m_num_events = m_symbol_map.size() - 1;
00289 }
00290
00291 void compute_modified_counts()
00292 {
00293 assert(!m_counts.empty());
00294 Progress p(m_progress_skip, m_counts.size() * 3);
00295 p.set_report_string("mod counts:");
00296
00297 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00298 {
00299 p.step();
00300 if (it->first.back() == m_sentence_start_id)
00301 continue;
00302 if (it->first.size() > 1) {
00303 m_sum_nonzero_xg[backoff(it->first)]++;
00304 m_sum_xg_not_pruned[backoff(it->first)] += it->second;
00305 m_sum_nonzero_xg_not_pruned[backoff(it->first)]++;
00306 }
00307 m_sum_nonzero_gx[parent(it->first)]++;
00308 }
00309
00310 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00311 {
00312 p.step();
00313 if (m_sum_nonzero_xg[it->first] == 0)
00314 m_sum_nonzero_xg[it->first] = it->second;
00315 }
00316
00317 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00318 {
00319 p.step();
00320 if (it->first.back() == m_sentence_start_id)
00321 continue;
00322 m_sum_nonzero_xgx[parent(it->first)] += get_sum_nonzero_xg(it->first);
00323 }
00324
00325 p.finish();
00326
00327 m_sum_nonzero_gx_not_pruned = m_sum_nonzero_gx;
00328 }
00329
00330 float ngram_prob(Ngram ngram) const
00331 {
00332 double ret = 1;
00333 while (!ngram.empty()) {
00334 if (ngram.size() == 1 && ngram.back() == m_sentence_start_id)
00335 break;
00336 ret *= prob_kn(ngram);
00337 ngram.pop_back();
00338 }
00339 return ret;
00340 }
00341
00342 void compute_d1()
00343 {
00344 double min_d1 = 1e100;
00345 double max_d1 = -1e100;
00346
00347 assert(!m_sum_nonzero_xg.empty());
00348 Progress p(m_progress_skip, m_counts.size());
00349 p.set_report_string("d1:");
00350
00351 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00352 {
00353 p.step();
00354 if (it->first.back() == m_sentence_start_id)
00355 continue;
00356 double orig = prob_kn(it->first);
00357 double lower = prob_kn_lower(it->first);
00358
00359 double d1 = ngram_prob(it->first) * log10(orig / lower);
00360 m_d1[it->first] = d1;
00361
00362 if (d1 < min_d1)
00363 min_d1 = d1;
00364 if (d1 > max_d1)
00365 max_d1 = d1;
00366
00367 if (0) {
00368 fprintf(stderr, "d1 = %12g %s\n", m_d1[it->first],
00369 ngram_str(it->first).c_str());
00370 }
00371
00372 }
00373 p.finish();
00374
00375 fprintf(stderr, "min_d1 = %g\n", min_d1);
00376 fprintf(stderr, "max_d1 = %g\n", max_d1);
00377 }
00378
00379 void compute_d2()
00380 {
00381 assert(!m_d1.empty());
00382 Progress p(m_progress_skip, m_counts.size());
00383 p.set_report_string("d2:");
00384
00385 FloatMap::iterator it = m_d1.end();
00386 while (it != m_d1.begin()) {
00387 it--;
00388 p.step();
00389
00390 FloatInt value(it->second, 1);
00391 FloatInt &d2 = m_d2[it->first];
00392
00393 d2.add(value);
00394 if (it->first.size() > 1)
00395 m_d2[parent(it->first)].add(d2);
00396
00397 if (0) {
00398 fprintf(stderr, "%12g / %d %s\n", d2.f, d2.i,
00399 ngram_str(it->first).c_str());
00400 }
00401 }
00402 p.finish();
00403 }
00404
00405 void compute_beta_denominator()
00406 {
00407 fprintf(stderr, "FATAL: compute_beta_denominator(): "
00408 "denominator is computed incorrecly\n");
00409 abort();
00410
00411 assert(!m_sum_xg_not_pruned.empty());
00412 Progress p(m_progress_skip, m_counts.size());
00413 p.set_report_string("beta denominator:");
00414 m_beta_denominator.clear();
00415
00416 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00417 {
00418 p.step();
00419 if (it->first.back() == m_sentence_start_id)
00420 continue;
00421
00422 double numerator = get_beta_numerator(it->first);
00423 m_beta_denominator[parent(it->first)] += numerator;
00424 }
00425 p.finish();
00426 }
00427
00428 bool is_pruned(const Ngram &ngram) const
00429 {
00430 if (ngram.size() == 1)
00431 return false;
00432 return get_d2_pair(ngram).i == 0;
00433 }
00434
00442 void prune_ngram(Ngram ngram)
00443 {
00444 assert(!m_d2.empty());
00445
00446 if (ngram.size() < 2)
00447 throw bit::invalid_argument(
00448 "bit::SimpleKneser::prune_ngram() ngram order < 2");
00449
00450 FloatIntMap::iterator it = m_d2.find(ngram);
00451 if (it == m_d2.end())
00452 throw bit::invalid_argument(
00453 "bit::SimpleKneser::prune_ngram() ngram not found");
00454
00455 if (it->second.i == 0)
00456 throw bit::invalid_argument(
00457 "bit::SimpleKneser::prune_ngram() ngram already pruned");
00458
00459 while (ngram.size() > 1) {
00460 ngram.pop_back();
00461 if (ngram.size() == 1 && ngram.back() == m_sentence_start_id)
00462 break;
00463 get_d2_pair(ngram).sub(it->second);
00464 }
00465
00466
00467
00468 fprintf(stderr, "pruned %12g %12g / %d %s\n",
00469 get_d1(it->first), it->second.f, it->second.i,
00470 ngram_str(it->first).c_str());
00471
00472 it->second.i = 0;
00473 m_sum_xg_not_pruned[backoff(it->first)] -= get_count(it->first);
00474 m_sum_nonzero_xg_not_pruned[backoff(it->first)]--;
00475 m_sum_nonzero_gx_not_pruned[parent(it->first)]--;
00476 size_t len = it->first.size();
00477 while (1) {
00478 it++;
00479 if (it == m_d2.end())
00480 break;
00481 if (it->first.size() <= len)
00482 break;
00483 if (it->second.i > 0) {
00484 fprintf(stderr, "pruned %12g %12g / %d %s\n",
00485 get_d1(it->first),
00486 it->second.f, it->second.i,
00487 ngram_str(it->first).c_str());
00488 it->second.i = 0;
00489 m_sum_xg_not_pruned[backoff(it->first)] -= get_count(it->first);
00490 m_sum_nonzero_xg_not_pruned[backoff(it->first)]--;
00491 m_sum_nonzero_gx_not_pruned[parent(it->first)]--;
00492 }
00493 }
00494 }
00495
00496 void prune(float threshold)
00497 {
00498 FloatIntMap::iterator it = m_d2.end();
00499 while (it != m_d2.begin()) {
00500 it--;
00501 if (it->second.value() < threshold)
00502 prune_ngram(it->first);
00503 }
00504 }
00505
00506 Ngram parent(const Ngram &ngram) const
00507 {
00508 assert(!ngram.empty());
00509 return Ngram(ngram.begin(), ngram.end() - 1);
00510 }
00511
00512 Ngram backoff(const Ngram &ngram) const
00513 {
00514 assert(ngram.size() > 1);
00515 return Ngram(ngram.begin() + 1, ngram.end());
00516 }
00517
00518 float inter_kn(const Ngram &ngram) const
00519 {
00520 IntMap::const_iterator it = m_sum_nonzero_gx.find(ngram);
00521 if (it == m_sum_nonzero_gx.end())
00522 return 1;
00523 return it->second * get_discount(ngram.size() + 1) /
00524 get_sum_nonzero_xgx(ngram);
00525 }
00526
00527 float prob_kn_lower(Ngram ngram) const
00528 {
00529 if (ngram.empty())
00530 throw bit::invalid_argument(
00531 "bit::SimpleKneser::prob_kn_lower() empty ngram");
00532
00533 if (ngram.back() == m_sentence_start_id)
00534 throw bit::invalid_argument(
00535 "bit::SimpleKneser::prob_kn_lower() sentence start");
00536
00537 double scale = 1;
00538 double ret = 0;
00539 while (1) {
00540 scale *= inter_kn(parent(ngram));
00541 ngram.erase(ngram.begin());
00542 if (ngram.empty())
00543 break;
00544
00545 int xg = get_sum_nonzero_xg(ngram);
00546 assert(xg != 0);
00547 if (xg >= 0) {
00548 double numerator = xg - get_discount(ngram.size());
00549 double denominator = get_sum_nonzero_xgx(parent(ngram));
00550 ret += scale * numerator / denominator;
00551 }
00552 }
00553 ret += scale / m_num_events;
00554 return ret;
00555 }
00556
00557 float prob_kn(const Ngram &ngram) const
00558 {
00559 double numerator = 0;
00560 double denominator = 1;
00561 int xg = get_sum_nonzero_xg(ngram);
00562 assert(xg != 0);
00563 if (xg > 0) {
00564 numerator = xg - get_discount(ngram.size());
00565 denominator = get_sum_nonzero_xgx(parent(ngram));
00566 }
00567 return numerator / denominator + prob_kn_lower(ngram);
00568 }
00569
00570 float inter_beta(const Ngram &ngram) const
00571 {
00572 assert(!m_beta_denominator.empty());
00573 assert(!m_sum_nonzero_xg_not_pruned.empty());
00574
00575 int value = get_sum_nonzero_gx_not_pruned(ngram);
00576 if (value == 0)
00577 return 1;
00578 return value * get_beta_discount(ngram.size() + 1) /
00579 get_beta_denominator(ngram);
00580 }
00581
00582 float prob_beta_lower(Ngram ngram) const
00583 {
00584 if (ngram.empty())
00585 throw bit::invalid_argument(
00586 "bit::SimpleKneser::prob_beta_lower() empty ngram");
00587
00588 if (ngram.back() == m_sentence_start_id)
00589 throw bit::invalid_argument(
00590 "bit::SimpleKneser::prob_beta_lower() sentence start");
00591
00592 double scale = 1;
00593 double ret = 0;
00594 while (1) {
00595 scale *= inter_beta(parent(ngram));
00596 ngram.erase(ngram.begin());
00597 if (ngram.empty())
00598 break;
00599
00600 if (is_pruned(ngram))
00601 continue;
00602
00603 double numerator = get_beta_numerator(ngram)
00604 - get_beta_discount(ngram.size());
00605 assert(numerator >= 0);
00606 double denominator = get_beta_denominator(parent(ngram));
00607 ret += scale * numerator / denominator;
00608 }
00609 ret += scale / m_num_events;
00610 return ret;
00611 }
00612
00613 float prob_beta(const Ngram &ngram) const
00614 {
00615 double numerator = 0;
00616 double denominator = 1;
00617
00618 if (!is_pruned(ngram)) {
00619 numerator = get_beta_numerator(ngram) -
00620 get_beta_discount(ngram.size());
00621 denominator = get_beta_denominator(parent(ngram));
00622 }
00623 return numerator / denominator + prob_beta_lower(ngram);
00624 }
00625
00626 std::string ngram_str(const Ngram &ngram)
00627 {
00628 std::string ret;
00629 for (size_t i = 0; i < ngram.size(); i++) {
00630 if (i > 0)
00631 ret.append(" ");
00632 ret.append(m_symbol_map.at(ngram[i]));
00633 }
00634 return ret;
00635 }
00636
00637 void write_beta_arpa(FILE *file)
00638 {
00639 unsigned int max_order = 0;
00640 for (size_t o = 1;; o++) {
00641
00642 if (max_order > 0 && o > max_order)
00643 break;
00644
00645 fprintf(stderr, "writing %zd-grams (max order %d)\n", o, max_order);
00646 fprintf(file, "\n\\%zd-grams:\n", o);
00647
00648 for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00649 {
00650 if (is_pruned(it->first))
00651 continue;
00652
00653 if (it->first.size() > max_order)
00654 max_order = it->first.size();
00655
00656 if (it->first.size() != o)
00657 continue;
00658
00659 float prob = -99;
00660 if (it->first.back() != m_sentence_start_id)
00661 prob = log10(prob_beta(it->first));
00662
00663 fprintf(file, "%g\t%s", prob,
00664 ngram_str(it->first).c_str());
00665 if (get_sum_nonzero_gx_not_pruned(it->first) > 0)
00666 fprintf(file, "\t%g", log10(inter_beta(it->first)));
00667 fputs("\n", file);
00668 }
00669 }
00670 }
00671
00672 private:
00673 std::string m_sentence_start_str;
00674 int m_sentence_start_id;
00675
00676 int m_num_events;
00677 int m_progress_skip;
00678 SymbolMap m_symbol_map;
00679 FloatVec m_discounts;
00680 FloatVec m_beta_discounts;
00681 IntMap m_counts;
00682 IntMap m_sum_nonzero_xg;
00683 IntMap m_sum_nonzero_xgx;
00684 IntMap m_sum_nonzero_gx;
00685 IntMap m_sum_xg_not_pruned;
00686 IntMap m_sum_nonzero_xg_not_pruned;
00687 IntMap m_sum_nonzero_gx_not_pruned;
00688 FloatMap m_beta_denominator;
00689 FloatMap m_d1;
00690 FloatIntMap m_d2;
00691 };
00692
00693 };
00694
00695 #endif