42 #include "../proto/data.pb.h"
43 #include "../proto/model.pb.h"
49 using confusion_learning::CandidateMessage;
50 using confusion_learning::CandidateSetMessage;
51 using confusion_learning::FeatureVecMessage;
52 using confusion_learning::FeatureMessage;
53 using confusion_learning::ScoreMessage;
57 using std::insert_iterator;
58 using std::ostream_iterator;
60 using std::unordered_map;
61 using std::shared_ptr;
69 if (m.has_source_key()) {
72 if (m.has_reference_string()) {
76 if (m.has_gold_index()) {
79 if (m.has_best_scoring_index()) {
83 int num_candidates = m.candidate_size();
84 int num_candidates_to_read =
85 max_candidates < 0 ? num_candidates :
86 (max_candidates > num_candidates ? num_candidates : max_candidates);
87 for (
int i = 0; i < num_candidates_to_read; ++i) {
88 const CandidateMessage &candidate_msg = m.candidate(i);
89 const FeatureVecMessage &feature_vec_msg = candidate_msg.feats();
94 for (
int j = 0; j < feature_vec_msg.feature_size(); ++j) {
95 const FeatureMessage &feature_msg = feature_vec_msg.feature(j);
96 if (feature_msg.has_name() && ! feature_msg.name().empty()) {
103 bool set_loss =
false;
105 double baseline_score = 0.0;
106 for (
int score_index = 0;
107 score_index < candidate_msg.score_size();
109 const ScoreMessage &score_msg = candidate_msg.score(score_index);
110 switch (score_msg.type()) {
111 case ScoreMessage::LOSS:
112 loss = score_msg.score();
115 case ScoreMessage::SYSTEM_SCORE:
120 baseline_score = score_msg.score();
122 case ScoreMessage::OUTPUT_SCORE:
127 int num_words = CountTokens(candidate_msg.raw_data());
130 cerr <<
"CandidateSetProtoReader: warning: computing loss by tokenizing"
131 <<
" and counting." << endl;
132 loss = ComputeLoss(set, candidate_msg.raw_data());
135 shared_ptr<Candidate> candidate(
new Candidate(i, loss, baseline_score,
137 candidate_msg.raw_data(),
138 features, symbolic_features));
145 const string &candidate_raw_data) {
147 vector<string> ref_toks;
149 vector<string> candidate_toks;
150 tokenizer_.
Tokenize(candidate_raw_data, candidate_toks);
151 std::set<string> ref_toks_set;
152 typedef vector<string>::const_iterator const_it;
153 for (const_it it = ref_toks.begin(); it != ref_toks.end(); ++it) {
154 ref_toks_set.insert(*it);
156 std::set<string> candidate_toks_set;
157 for (const_it it = candidate_toks.begin(); it != candidate_toks.end();
159 candidate_toks_set.insert(*it);
161 std::set<string> intersection;
162 insert_iterator<std::set<string> > ii(intersection, intersection.begin());
163 set_intersection(ref_toks_set.begin(), ref_toks_set.end(),
164 candidate_toks_set.begin(), candidate_toks_set.end(),
169 ostream_iterator<string> tab_delimited(cout,
"\n\t");
170 cout <<
"Set of ref toks:\n\t";
171 copy(ref_toks_set.begin(), ref_toks_set.end(), tab_delimited);
173 cout <<
"Set of candidate toks:\n\t";
174 copy(candidate_toks_set.begin(), candidate_toks_set.end(), tab_delimited);
176 cout <<
"Intersection:\n\t";
177 copy(intersection.begin(), intersection.end(), tab_delimited);
181 double loss = intersection.size() / (double)ref_toks_set.size();
184 cout <<
"Intersection size is " << intersection.size()
185 <<
" and there are " << ref_toks_set.size()
186 <<
" ref toks, so loss is " << loss <<
"." << endl;
const string & reference_string() const
void Read(const CandidateSetMessage &m, CandidateSet &set)
Fills in the specified CandidateSet based on the specified CandidateSetMessage, crucially constructin...
void set_reference_string(const string &reference_string)
Provides the reranker::Symbols interface as well as the reranker::StaticSymbolTable implementation...
void Tokenize(const string &s, vector< string > &toks, const char *delimiters=" \t") const
Tokenizes the specified string, depositing the results into the specified vector. ...
void AddCandidate(shared_ptr< Candidate > candidate)
void set_training_key(const string &training_key)
void set_gold_index(size_t index)
A class to hold a set of candidates, either for training or test.
A class to represent a candidate in a set of candidates that constitutes a training instance for a re...
void set_reference_string_token_count(int reference_string_token_count)
Reads CandidateSetMessage instances and converts them to reranker::CandidateSet instances.
V IncrementWeight(const K &uid, V by)
Increments the weight of the specified feature by the specified amount.
void set_best_scoring_index(size_t index)