00001 #ifndef PERPLEXITY_HH 00002 #define PERPLEXITY_HH 00003 00004 #include "lm/LM.hh" 00005 00006 namespace bit { 00007 00009 class Perplexity { 00010 public: 00011 00013 Perplexity(const LM &lm) 00014 : m_lm(&lm) 00015 { 00016 opt.word_boundary_str = "<w>"; 00017 opt.unk_str = ""; 00018 reset(); 00019 } 00020 00022 void reset() 00023 { 00024 m_start_pending = true; 00025 m_score = 0; 00026 m_num_symbols = 0; 00027 m_num_words = 0; 00028 m_num_sentences = 0; 00029 m_it = m_lm->root(); 00030 } 00031 00033 int num_symbols() const 00034 { 00035 return m_num_symbols; 00036 } 00037 00039 int num_words() const 00040 { 00041 return m_num_words; 00042 } 00043 00045 double score() const 00046 { 00047 return m_score; 00048 } 00049 00054 float cross_entropy_per_word() const 00055 { 00056 if (m_num_words == 0) 00057 throw bit::invalid_call( 00058 "bit::Perplexity::cross_entropy_per_word() no words yet"); 00059 return m_score * 3.3219280949 / m_num_words; 00060 } 00061 00071 float add_symbol(const std::string &symbol_str) 00072 { 00073 int symbol = m_lm->symbol_map().index_nothrow(symbol_str); 00074 bool unk = false; 00075 if (symbol < 0) { 00076 if (opt.unk_str.empty()) 00077 throw bit::invalid_argument( 00078 std::string("bit::Perplexity::add_symbol(): invalid symbol \"") + 00079 symbol_str + "\""); 00080 symbol = m_lm->symbol_map().index_nothrow(opt.unk_str); 00081 if (symbol < 0) { 00082 throw bit::invalid_argument( 00083 std::string("bit::Perplexity::add_symbol(): invalid unk symbol \"") 00084 + opt.unk_str + "\""); 00085 } 00086 unk = true; 00087 } 00088 00089 if (symbol == m_lm->start_symbol() && !m_start_pending) 00090 throw bit::invalid_argument( 00091 "bit::Perplexity::add_symbol() unexpected start symbol"); 00092 if (symbol != m_lm->start_symbol() && m_start_pending) 00093 throw bit::invalid_argument( 00094 "bit::Perplexity::add_symbol() expected start symbol but got \"" + 00095 symbol_str + "\""); 00096 00097 float score = m_lm->walk(m_it, symbol); 00098 00099 if (symbol == m_lm->start_symbol()) { 00100 m_start_pending = false; 00101 assert(score == 0); 00102 } 00103 else if (symbol != m_lm->end_symbol() && !unk) { 00104 m_num_symbols++; 00105 } 00106 00107 if (!unk) 00108 m_score += score; 00109 00110 if (symbol_str == opt.word_boundary_str) 00111 m_num_words++; 00112 00113 if (symbol == m_lm->end_symbol()) { 00114 m_start_pending = true; 00115 m_num_sentences++; 00116 m_num_words--; 00117 } 00118 00119 return score; 00120 } 00121 00122 public: 00123 struct { 00127 std::string word_boundary_str; 00128 00130 std::string unk_str; 00131 } opt; 00132 00133 private: 00134 00136 bool m_start_pending; 00137 00139 double m_score; 00140 00142 int m_num_symbols; 00143 00145 int m_num_words; 00146 00148 int m_num_sentences; 00149 00151 const LM *m_lm; 00152 00154 LM::Iterator m_it; 00155 00156 }; 00157 00158 }; 00159 00160 #endif /* PERPLEXITY_HH */