SimpleKneser.hh

Go to the documentation of this file.
00001 #ifndef SIMPLEKNESER_HH
00002 #define SIMPLEKNESER_HH
00003 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>

#include "lm/SymbolMap.hh"
#include "str/str.hh"
#include "util/Progress.hh"
#include "bit/exceptions.hh"
00010 
00011 namespace bit {
00012 
00013   class SimpleKneser {
00014   public:
00015 
00016     typedef SymbolMap<std::string,int> SymbolMap;
00017     typedef std::vector<int> Ngram;
00018     typedef std::vector<std::string> StrNgram;
00019     typedef std::vector<int> IntVec;
00020     typedef std::vector<float> FloatVec;
00021 
00022 //     struct NgramCompare {
00023 //       bool operator()(const Ngram &a, const Ngram &b) const
00024 //       {
00025 //         if (a.size() < b.size())
00026 //           return true;
00027 //         if (a.size() > b.size())
00028 //           return false;
00029 //         return a < b;
00030 //       }
00031 //     };
00032 
00033     // d2 measure needs float nominator and int denominator.  For
00034     // pruned ngrams the integer denominator is zero.
00035     struct FloatInt {
00036       FloatInt() : f(0), i(0) { }
00037       FloatInt(float f, int i) : f(f), i(i) { }
00038       float f;
00039       int i;
00040 
00041       void add(const FloatInt &value)
00042       {
00043         f += value.f;
00044         i += value.i;
00045         assert(i > 0);
00046       }
00047 
00048       void sub(const FloatInt &value)
00049       {
00050         f -= value.f;
00051         i -= value.i;
00052         assert(i > 0);
00053       }
00054       
00055       float value() const
00056       {
00057         return f / i;
00058       }
00059 
00060     };
00061 
00062     // Maps used for caching statistics 
00063     typedef std::map<Ngram, int> IntMap;
00064     typedef std::map<Ngram, float> FloatMap;
00065     typedef std::map<Ngram, FloatInt> FloatIntMap;
00066 
00067     SimpleKneser()
00068     {
00069       init();
00070     }
00071 
00072     SimpleKneser(const std::string &str)
00073     {
00074       init();
00075       StrNgram str_ngram = str::split(str, " \t", true);
00076       for (size_t i = 0; i < str_ngram.size(); i++)
00077         m_symbol_map.insert(str_ngram[i]);
00078       m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00079       m_num_events = m_symbol_map.size() - 1;
00080     }
00081 
00082     void init()
00083     {
00084       m_progress_skip = 7351;
00085       m_sentence_start_str = "<s>";
00086       m_sentence_start_id = -1;
00087     }
00088 
00089     void set_discounts(const std::string &str)
00090     {
00091       m_discounts = str::float_vec(str);
00092       if (m_discounts.empty())
00093         throw bit::invalid_argument(
00094           "bit::SimpleKneser::set_discounts() empty vector");
00095       // Dummy value to get unigram discount on index 1
00096       m_discounts.insert(m_discounts.begin(), -1e30);
00097 
00098       m_beta_discounts = m_discounts;
00099       for (size_t i = m_beta_discounts.size() - 1; i > 0; i--)
00100         m_beta_discounts[i-1] *= m_beta_discounts[i];
00101     }
00102 
00103     float get_discount(unsigned int order) const
00104     {
00105       if (m_discounts.empty())
00106         throw bit::invalid_call(
00107           "bit::SimpleKneser::get_discount() discounts not set");
00108       if (order == 0)
00109         throw bit::invalid_argument(
00110           "bit::SimpleKneser::get_discount() trying to get zero order");
00111       if (order >= m_discounts.size())
00112         return m_discounts.back();
00113       return m_discounts.at(order);
00114     }
00115 
00116     float get_beta_discount(unsigned int order) const
00117     {
00118       if (m_beta_discounts.empty())
00119         throw bit::invalid_call(
00120           "bit::SimpleKneser::get_beta_discount() beta_discounts not set");
00121       if (order == 0)
00122         throw bit::invalid_argument(
00123           "bit::SimpleKneser::get_beta_discount() trying to get zero order");
00124       if (order >= m_beta_discounts.size())
00125         return m_beta_discounts.back();
00126       return m_beta_discounts.at(order);
00127     }
00128 
00129     Ngram ngram(const std::string &str) const
00130     {
00131       StrNgram str_ngram = str::split(str, " \t", true);
00132       Ngram ngram(str_ngram.size());
00133       for (size_t i = 0; i < str_ngram.size(); i++)
00134         ngram[i] = m_symbol_map.index(str_ngram[i]);
00135       return ngram;
00136     }
00137 
00138     int get_count(const Ngram &ngram) const
00139     {
00140       IntMap::const_iterator it = m_counts.find(ngram);
00141       if (it == m_counts.end())
00142         throw bit::invalid_argument(
00143           "bit::SimpleKneser::get_count() ngram not found");
00144       return it->second;
00145     }
00146 
00147     int get_sum_nonzero_xg(const Ngram &ngram) const
00148     {
00149       IntMap::const_iterator it = m_sum_nonzero_xg.find(ngram);
00150       if (it == m_sum_nonzero_xg.end())
00151         return -1;
00152       return it->second;
00153     }
00154 
00155     int get_sum_nonzero_xgx(const Ngram &ngram) const
00156     {
00157       IntMap::const_iterator it = m_sum_nonzero_xgx.find(ngram);
00158       if (it == m_sum_nonzero_xgx.end())
00159         throw bit::invalid_argument(
00160           "bit::SimpleKneser::get_sum_nonzero_xgx() ngram not found");
00161       return it->second;
00162     }
00163 
00164     int get_sum_nonzero_gx(const Ngram &ngram) const
00165     {
00166       IntMap::const_iterator it = m_sum_nonzero_gx.find(ngram);
00167       if (it == m_sum_nonzero_gx.end())
00168         throw bit::invalid_argument(
00169           "bit::SimpleKneser::get_sum_nonzero_gx() ngram not found");
00170       assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00171       return it->second;
00172     }
00173 
00174     int get_sum_xg_not_pruned(const Ngram &ngram) const
00175     {
00176       if (!ngram.empty() && ngram.back() == m_sentence_start_id)
00177         throw bit::invalid_argument(
00178           "bit::SimpleKneser::get_sum_xg_not_pruned() sentence start");
00179       IntMap::const_iterator it = m_sum_xg_not_pruned.find(ngram);
00180       if (it == m_sum_xg_not_pruned.end())
00181         return 0;
00182       return it->second;
00183     }
00184 
00185     int get_sum_nonzero_xg_not_pruned(const Ngram &ngram) const
00186     {
00187       if (!ngram.empty() && ngram.back() == m_sentence_start_id)
00188         throw bit::invalid_argument(
00189           "bit::SimpleKneser::get_sum_nonzero_xg_not_pruned() sentence start");
00190       IntMap::const_iterator it = m_sum_nonzero_xg_not_pruned.find(ngram);
00191       if (it == m_sum_nonzero_xg_not_pruned.end())
00192         return 0;
00193       return it->second;
00194     }
00195 
00196     int get_sum_nonzero_gx_not_pruned(const Ngram &ngram) const
00197     {
00198       IntMap::const_iterator it = m_sum_nonzero_gx_not_pruned.find(ngram);
00199       if (it == m_sum_nonzero_gx_not_pruned.end())
00200         return 0;
00201       return it->second;
00202     }
00203 
00204     float get_d1(const Ngram &ngram) const
00205     {
00206       FloatMap::const_iterator it = m_d1.find(ngram);
00207       if (it == m_d1.end())
00208         throw bit::invalid_argument(
00209           "bit::SimpleKneser::get_d1() ngram not found");
00210       assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00211       return it->second;
00212     }
00213 
00214     float get_d2(const Ngram &ngram) const
00215     {
00216       FloatInt value = get_d2_pair(ngram);
00217       return value.f / value.i;
00218     }
00219 
00220     const FloatInt &get_d2_pair(const Ngram &ngram) const
00221     {
00222       FloatIntMap::const_iterator it = m_d2.find(ngram);
00223       if (it == m_d2.end())
00224         throw bit::invalid_argument(
00225           "bit::SimpleKneser::get_d2_pair() ngram not found");
00226       assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00227       return it->second;
00228     }
00229 
00230     FloatInt &get_d2_pair(const Ngram &ngram) 
00231     {
00232       FloatIntMap::iterator it = m_d2.find(ngram);
00233       if (it == m_d2.end())
00234         throw bit::invalid_argument(
00235           "bit::SimpleKneser::get_d2_pair() ngram not found");
00236       assert(it->first.empty() || it->first.back() != m_sentence_start_id);
00237       return it->second;
00238     }
00239 
00240     float get_beta_numerator(const Ngram &ngram) const
00241     {
00242       if (is_pruned(ngram))
00243         return 0;
00244       return get_count(ngram) - get_sum_xg_not_pruned(ngram) +
00245         get_beta_discount(ngram.size() + 1) * 
00246         get_sum_nonzero_xg_not_pruned(ngram);
00247     }
00248 
00249     float get_beta_denominator(const Ngram &ngram) const
00250     {
00251       FloatMap::const_iterator it = m_beta_denominator.find(ngram);
00252       if (it == m_beta_denominator.end())
00253         throw bit::invalid_argument(
00254           "bit::SimpleKneser::get_beta_denominator() ngram not found");
00255       return it->second;
00256     }
00257 
00258     void add_count(const Ngram &ngram, int count)
00259     {
00260       m_counts[ngram] += count;
00261     }
00262 
00263     void read_counts(FILE *file)
00264     {
00265       std::string line;
00266       StrNgram str_ngram;
00267       Progress p(m_progress_skip);
00268       p.set_report_string("reading counts:");
00269       while (str::read_line(line, file, true)) {
00270         str_ngram = str::split(line, " \t", true);
00271         if (str_ngram.size() < 2)
00272           throw bit::io_error(
00273             "bit::SimpleKneser::read_counts(): invalid line");
00274         int count = str::str2long(str_ngram.back());
00275         str_ngram.pop_back();
00276 
00277         Ngram ngram(str_ngram.size());
00278         for (size_t i = 0; i < str_ngram.size(); i++)
00279           ngram[i] = m_symbol_map.insert(str_ngram[i]);
00280 
00281         m_counts[ngram] += count;
00282         p.step();
00283       }
00284       p.finish();
00285 
00286       assert(m_sentence_start_id < 0);
00287       m_sentence_start_id = m_symbol_map.index(m_sentence_start_str);
00288       m_num_events = m_symbol_map.size() - 1;
00289     }
00290 
00291     void compute_modified_counts()
00292     {
00293       assert(!m_counts.empty());
00294       Progress p(m_progress_skip, m_counts.size() * 3);
00295       p.set_report_string("mod counts:");
00296 
00297       for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00298       {
00299         p.step();
00300         if (it->first.back() == m_sentence_start_id)
00301           continue;
00302         if (it->first.size() > 1) {
00303           m_sum_nonzero_xg[backoff(it->first)]++;
00304           m_sum_xg_not_pruned[backoff(it->first)] += it->second;
00305           m_sum_nonzero_xg_not_pruned[backoff(it->first)]++;
00306         }
00307         m_sum_nonzero_gx[parent(it->first)]++;
00308       }
00309 
00310       for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00311       {
00312         p.step();
00313         if (m_sum_nonzero_xg[it->first] == 0)
00314           m_sum_nonzero_xg[it->first] = it->second;
00315       }
00316 
00317       for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00318       {
00319         p.step();
00320         if (it->first.back() == m_sentence_start_id)
00321           continue;
00322         m_sum_nonzero_xgx[parent(it->first)] += get_sum_nonzero_xg(it->first);
00323       }
00324 
00325       p.finish();
00326 
00327       m_sum_nonzero_gx_not_pruned = m_sum_nonzero_gx;
00328     }
00329 
00330     float ngram_prob(Ngram ngram) const
00331     {
00332       double ret = 1;
00333       while (!ngram.empty()) {
00334         if (ngram.size() == 1 && ngram.back() == m_sentence_start_id)
00335           break;
00336         ret *= prob_kn(ngram);
00337         ngram.pop_back();
00338       }
00339       return ret;
00340     }
00341 
00342     void compute_d1()
00343     {
00344       double min_d1 = 1e100;
00345       double max_d1 = -1e100;
00346 
00347       assert(!m_sum_nonzero_xg.empty());
00348       Progress p(m_progress_skip, m_counts.size());
00349       p.set_report_string("d1:");
00350 
00351       for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00352       {
00353         p.step();
00354         if (it->first.back() == m_sentence_start_id)
00355           continue;
00356         double orig = prob_kn(it->first);
00357         double lower = prob_kn_lower(it->first);
00358 
00359         double d1 = ngram_prob(it->first) * log10(orig / lower);
00360         m_d1[it->first] = d1;
00361 
00362         if (d1 < min_d1)
00363           min_d1 = d1;
00364         if (d1 > max_d1)
00365           max_d1 = d1;
00366 
00367         if (0) {
00368           fprintf(stderr, "d1 = %12g  %s\n", m_d1[it->first],
00369                   ngram_str(it->first).c_str());
00370         }
00371         
00372       }
00373       p.finish();
00374 
00375       fprintf(stderr, "min_d1 = %g\n", min_d1);
00376       fprintf(stderr, "max_d1 = %g\n", max_d1);
00377     }
00378 
00379     void compute_d2()
00380     {
00381       assert(!m_d1.empty());
00382       Progress p(m_progress_skip, m_counts.size());
00383       p.set_report_string("d2:");
00384 
00385       FloatMap::iterator it = m_d1.end(); 
00386       while (it != m_d1.begin()) {
00387         it--;
00388         p.step();
00389 
00390         FloatInt value(it->second, 1);
00391         FloatInt &d2 = m_d2[it->first];
00392 
00393         d2.add(value);
00394         if (it->first.size() > 1)
00395           m_d2[parent(it->first)].add(d2);
00396 
00397         if (0) {
00398           fprintf(stderr, "%12g / %d %s\n", d2.f, d2.i, 
00399                   ngram_str(it->first).c_str());
00400         }
00401       }
00402       p.finish();
00403     }
00404 
00405     void compute_beta_denominator()
00406     {
00407       fprintf(stderr, "FATAL: compute_beta_denominator(): "
00408               "denominator is computed incorrecly\n");
00409       abort();
00410 
00411       assert(!m_sum_xg_not_pruned.empty());
00412       Progress p(m_progress_skip, m_counts.size());
00413       p.set_report_string("beta denominator:");
00414       m_beta_denominator.clear();
00415 
00416       for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++)
00417       {
00418         p.step();
00419         if (it->first.back() == m_sentence_start_id)
00420           continue;
00421 
00422         double numerator = get_beta_numerator(it->first);
00423         m_beta_denominator[parent(it->first)] += numerator;
00424       }
00425       p.finish();
00426     }
00427 
00428     bool is_pruned(const Ngram &ngram) const
00429     {
00430       if (ngram.size() == 1)
00431         return false;
00432       return get_d2_pair(ngram).i == 0;
00433     }
00434 
00442     void prune_ngram(Ngram ngram)
00443     {
00444       assert(!m_d2.empty());
00445 
00446       if (ngram.size() < 2)
00447         throw bit::invalid_argument(
00448           "bit::SimpleKneser::prune_ngram() ngram order < 2");
00449 
00450       FloatIntMap::iterator it = m_d2.find(ngram);
00451       if (it == m_d2.end())
00452         throw bit::invalid_argument(
00453           "bit::SimpleKneser::prune_ngram() ngram not found");
00454       
00455       if (it->second.i == 0)
00456         throw bit::invalid_argument(
00457           "bit::SimpleKneser::prune_ngram() ngram already pruned");
00458 
00459       while (ngram.size() > 1) {
00460         ngram.pop_back();
00461         if (ngram.size() == 1 && ngram.back() == m_sentence_start_id)
00462           break;
00463         get_d2_pair(ngram).sub(it->second);
00464       }
00465 
00466       // Mark ngram and its children pruned
00467       //
00468       fprintf(stderr, "pruned %12g %12g / %d %s\n", 
00469               get_d1(it->first), it->second.f, it->second.i,
00470               ngram_str(it->first).c_str());
00471 
00472       it->second.i = 0;
00473       m_sum_xg_not_pruned[backoff(it->first)] -= get_count(it->first);
00474       m_sum_nonzero_xg_not_pruned[backoff(it->first)]--;
00475       m_sum_nonzero_gx_not_pruned[parent(it->first)]--;
00476       size_t len = it->first.size();
00477       while (1) {
00478         it++;
00479         if (it == m_d2.end())
00480           break;
00481         if (it->first.size() <= len)
00482           break;
00483         if (it->second.i > 0) {
00484           fprintf(stderr, "pruned %12g %12g / %d %s\n", 
00485                   get_d1(it->first),
00486                   it->second.f, it->second.i,
00487                   ngram_str(it->first).c_str());
00488           it->second.i = 0;
00489           m_sum_xg_not_pruned[backoff(it->first)] -= get_count(it->first);
00490           m_sum_nonzero_xg_not_pruned[backoff(it->first)]--;
00491           m_sum_nonzero_gx_not_pruned[parent(it->first)]--;
00492         }
00493       }
00494     }
00495 
00496     void prune(float threshold)
00497     {
00498       FloatIntMap::iterator it = m_d2.end();
00499       while (it != m_d2.begin()) {
00500         it--;
00501         if (it->second.value() < threshold)
00502           prune_ngram(it->first);
00503       }
00504     }
00505 
00506     Ngram parent(const Ngram &ngram) const
00507     {
00508       assert(!ngram.empty());
00509       return Ngram(ngram.begin(), ngram.end() - 1);
00510     }
00511 
00512     Ngram backoff(const Ngram &ngram) const
00513     {
00514       assert(ngram.size() > 1);
00515       return Ngram(ngram.begin() + 1, ngram.end());
00516     }
00517 
00518     float inter_kn(const Ngram &ngram) const
00519     {
00520       IntMap::const_iterator it = m_sum_nonzero_gx.find(ngram);
00521       if (it == m_sum_nonzero_gx.end())
00522         return 1;
00523       return it->second * get_discount(ngram.size() + 1) / 
00524         get_sum_nonzero_xgx(ngram);
00525     }
00526 
00527     float prob_kn_lower(Ngram ngram) const
00528     {
00529       if (ngram.empty())
00530         throw bit::invalid_argument(
00531           "bit::SimpleKneser::prob_kn_lower() empty ngram");
00532 
00533       if (ngram.back() == m_sentence_start_id)
00534         throw bit::invalid_argument(
00535           "bit::SimpleKneser::prob_kn_lower() sentence start");
00536 
00537       double scale = 1;
00538       double ret = 0;
00539       while (1) {
00540         scale *= inter_kn(parent(ngram));
00541         ngram.erase(ngram.begin());
00542         if (ngram.empty())
00543           break;
00544 
00545         int xg = get_sum_nonzero_xg(ngram);
00546         assert(xg != 0);
00547         if (xg >= 0) {
00548           double numerator = xg - get_discount(ngram.size());
00549           double denominator = get_sum_nonzero_xgx(parent(ngram));
00550           ret += scale * numerator / denominator;
00551         }
00552       }
00553       ret += scale / m_num_events;
00554       return ret;
00555     }
00556 
00557     float prob_kn(const Ngram &ngram) const
00558     {
00559       double numerator = 0;
00560       double denominator = 1;
00561       int xg = get_sum_nonzero_xg(ngram);
00562       assert(xg != 0);
00563       if (xg > 0) {
00564         numerator = xg - get_discount(ngram.size());
00565         denominator = get_sum_nonzero_xgx(parent(ngram));
00566       }
00567       return numerator / denominator + prob_kn_lower(ngram);
00568     }
00569 
00570     float inter_beta(const Ngram &ngram) const
00571     {
00572       assert(!m_beta_denominator.empty());
00573       assert(!m_sum_nonzero_xg_not_pruned.empty());
00574 
00575       int value = get_sum_nonzero_gx_not_pruned(ngram);
00576       if (value == 0)
00577         return 1;
00578       return value * get_beta_discount(ngram.size() + 1) / 
00579          get_beta_denominator(ngram);
00580     }
00581 
00582     float prob_beta_lower(Ngram ngram) const
00583     {
00584       if (ngram.empty())
00585         throw bit::invalid_argument(
00586           "bit::SimpleKneser::prob_beta_lower() empty ngram");
00587 
00588       if (ngram.back() == m_sentence_start_id)
00589         throw bit::invalid_argument(
00590           "bit::SimpleKneser::prob_beta_lower() sentence start");
00591 
00592       double scale = 1;
00593       double ret = 0;
00594       while (1) {
00595         scale *= inter_beta(parent(ngram));
00596         ngram.erase(ngram.begin());
00597         if (ngram.empty())
00598           break;
00599 
00600         if (is_pruned(ngram))
00601           continue;
00602 
00603         double numerator = get_beta_numerator(ngram) 
00604           - get_beta_discount(ngram.size());
00605         assert(numerator >= 0);
00606         double denominator = get_beta_denominator(parent(ngram));
00607         ret += scale * numerator / denominator;
00608       }
00609       ret += scale / m_num_events;
00610       return ret;
00611     }
00612 
00613     float prob_beta(const Ngram &ngram) const
00614     {
00615       double numerator = 0;
00616       double denominator = 1;
00617 
00618       if (!is_pruned(ngram)) {
00619         numerator = get_beta_numerator(ngram) - 
00620           get_beta_discount(ngram.size());
00621         denominator = get_beta_denominator(parent(ngram));
00622       }
00623       return numerator / denominator + prob_beta_lower(ngram);
00624     }
00625 
00626     std::string ngram_str(const Ngram &ngram)
00627     {
00628       std::string ret;
00629       for (size_t i = 0; i < ngram.size(); i++) {
00630         if (i > 0)
00631           ret.append(" ");
00632         ret.append(m_symbol_map.at(ngram[i]));
00633       }
00634       return ret;
00635     }
00636 
00637     void write_beta_arpa(FILE *file)
00638     {
00639       unsigned int max_order = 0;
00640       for (size_t o = 1;; o++) {
00641       
00642         if (max_order > 0 && o > max_order)
00643           break;
00644 
00645         fprintf(stderr, "writing %zd-grams (max order %d)\n", o, max_order);
00646         fprintf(file, "\n\\%zd-grams:\n", o);
00647 
00648         for (IntMap::iterator it = m_counts.begin(); it != m_counts.end(); it++) 
00649         {
00650           if (is_pruned(it->first))
00651             continue;
00652 
00653           if (it->first.size() > max_order)
00654             max_order = it->first.size();
00655 
00656           if (it->first.size() != o)
00657             continue;
00658 
00659           float prob = -99;
00660           if (it->first.back() != m_sentence_start_id)
00661             prob = log10(prob_beta(it->first));
00662 
00663           fprintf(file, "%g\t%s", prob,
00664                   ngram_str(it->first).c_str());
00665           if (get_sum_nonzero_gx_not_pruned(it->first) > 0)
00666             fprintf(file, "\t%g", log10(inter_beta(it->first)));
00667           fputs("\n", file);
00668         }
00669       }
00670     }
00671 
00672   private:
00673     std::string m_sentence_start_str;
00674     int m_sentence_start_id;
00675 
00676     int m_num_events;
00677     int m_progress_skip;
00678     SymbolMap m_symbol_map;
00679     FloatVec m_discounts;
00680     FloatVec m_beta_discounts;
00681     IntMap m_counts;
00682     IntMap m_sum_nonzero_xg;
00683     IntMap m_sum_nonzero_xgx;
00684     IntMap m_sum_nonzero_gx;
00685     IntMap m_sum_xg_not_pruned;
00686     IntMap m_sum_nonzero_xg_not_pruned;
00687     IntMap m_sum_nonzero_gx_not_pruned;
00688     FloatMap m_beta_denominator;
00689     FloatMap m_d1;
00690     FloatIntMap m_d2;
00691   };
00692 
00693 };
00694 
00695 #endif /* SIMPLEKNESER_HH */

Generated on Mon Jan 8 15:51:03 2007 for bit by  doxygen 1.4.6