Reranker Framework (ReFr)
Reranking framework for structured prediction and discriminative language modeling
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
candidate-set-proto-reader.C
Go to the documentation of this file.
1 // Copyright 2012, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 // -----------------------------------------------------------------------------
30 //
31 //
36 
37 #include <algorithm>
38 #include <iterator>
39 #include <set>
40 #include <memory>
41 
42 #include "../proto/data.pb.h"
43 #include "../proto/model.pb.h"
44 #include "symbol-table.H"
46 
47 #define DEBUG 0
48 
49 using confusion_learning::CandidateMessage;
50 using confusion_learning::CandidateSetMessage;
51 using confusion_learning::FeatureVecMessage;
52 using confusion_learning::FeatureMessage;
53 using confusion_learning::ScoreMessage;
54 
55 using std::cout;
56 using std::endl;
57 using std::insert_iterator;
58 using std::ostream_iterator;
59 using std::string;
60 using std::unordered_map;
61 using std::shared_ptr;
62 
63 namespace reranker {
64 
65 void
66 CandidateSetProtoReader::Read(const CandidateSetMessage &m,
67  int max_candidates,
68  CandidateSet &set) {
69  if (m.has_source_key()) {
70  set.set_training_key(m.source_key());
71  }
72  if (m.has_reference_string()) {
73  set.set_reference_string(m.reference_string());
74  set.set_reference_string_token_count(CountTokens(set.reference_string()));
75  }
76  if (m.has_gold_index()) {
77  set.set_gold_index(m.gold_index());
78  }
79  if (m.has_best_scoring_index()) {
80  set.set_best_scoring_index(m.best_scoring_index());
81  }
82 
83  int num_candidates = m.candidate_size();
84  int num_candidates_to_read =
85  max_candidates < 0 ? num_candidates :
86  (max_candidates > num_candidates ? num_candidates : max_candidates);
87  for (int i = 0; i < num_candidates_to_read; ++i) {
88  const CandidateMessage &candidate_msg = m.candidate(i);
89  const FeatureVecMessage &feature_vec_msg = candidate_msg.feats();
90 
91  FeatureVector<string,double> symbolic_features;
93 
94  for (int j = 0; j < feature_vec_msg.feature_size(); ++j) {
95  const FeatureMessage &feature_msg = feature_vec_msg.feature(j);
96  if (feature_msg.has_name() && ! feature_msg.name().empty()) {
97  symbolic_features.IncrementWeight(feature_msg.name(),
98  feature_msg.value());
99  } else {
100  features.IncrementWeight(feature_msg.id(), feature_msg.value());
101  }
102  }
103  bool set_loss = false;
104  double loss = 0.0;
105  double baseline_score = 0.0;
106  for (int score_index = 0;
107  score_index < candidate_msg.score_size();
108  ++score_index) {
109  const ScoreMessage &score_msg = candidate_msg.score(score_index);
110  switch (score_msg.type()) {
111  case ScoreMessage::LOSS:
112  loss = score_msg.score();
113  set_loss = true;
114  break;
115  case ScoreMessage::SYSTEM_SCORE:
116  // TODO(dbikel): Deal with the fact that there could be multiple
117  // system scores one day, perhaps simply storing
118  // them in an array (right now, we just take the
119  // last one).
120  baseline_score = score_msg.score();
121  break;
122  case ScoreMessage::OUTPUT_SCORE:
123  break;
124  }
125  }
126 
127  int num_words = CountTokens(candidate_msg.raw_data());
128 
129  if (!set_loss) {
130  cerr << "CandidateSetProtoReader: warning: computing loss by tokenizing"
131  << " and counting." << endl;
132  loss = ComputeLoss(set, candidate_msg.raw_data());
133  }
134 
135  shared_ptr<Candidate> candidate(new Candidate(i, loss, baseline_score,
136  num_words,
137  candidate_msg.raw_data(),
138  features, symbolic_features));
139  set.AddCandidate(candidate);
140  }
141 }
142 
143 double
144 CandidateSetProtoReader::ComputeLoss(CandidateSet &set,
145  const string &candidate_raw_data) {
146  // For now, find loss for candidate by doing "position-independent WER".
147  vector<string> ref_toks;
148  tokenizer_.Tokenize(set.reference_string(), ref_toks);
149  vector<string> candidate_toks;
150  tokenizer_.Tokenize(candidate_raw_data, candidate_toks);
151  std::set<string> ref_toks_set;
152  typedef vector<string>::const_iterator const_it;
153  for (const_it it = ref_toks.begin(); it != ref_toks.end(); ++it) {
154  ref_toks_set.insert(*it);
155  }
156  std::set<string> candidate_toks_set;
157  for (const_it it = candidate_toks.begin(); it != candidate_toks.end();
158  ++it) {
159  candidate_toks_set.insert(*it);
160  }
161  std::set<string> intersection;
162  insert_iterator<std::set<string> > ii(intersection, intersection.begin());
163  set_intersection(ref_toks_set.begin(), ref_toks_set.end(),
164  candidate_toks_set.begin(), candidate_toks_set.end(),
165  ii);
166 
167 
168  if (DEBUG) {
169  ostream_iterator<string> tab_delimited(cout, "\n\t");
170  cout << "Set of ref toks:\n\t";
171  copy(ref_toks_set.begin(), ref_toks_set.end(), tab_delimited);
172  cout << endl;
173  cout << "Set of candidate toks:\n\t";
174  copy(candidate_toks_set.begin(), candidate_toks_set.end(), tab_delimited);
175  cout << endl;
176  cout << "Intersection:\n\t";
177  copy(intersection.begin(), intersection.end(), tab_delimited);
178  cout << endl;
179  }
180 
181  double loss = intersection.size() / (double)ref_toks_set.size();
182 
183  if (DEBUG) {
184  cout << "Intersection size is " << intersection.size()
185  << " and there are " << ref_toks_set.size()
186  << " ref toks, so loss is " << loss << "." << endl;
187  }
188 
189  return loss;
190 }
191 
192 } // namespace reranker
const string & reference_string() const
void Read(const CandidateSetMessage &m, CandidateSet &set)
Fills in the specified CandidateSet based on the specified CandidateSetMessage, crucially constructing ...
void set_reference_string(const string &reference_string)
Provides the reranker::Symbols interface as well as the reranker::StaticSymbolTable implementation...
void Tokenize(const string &s, vector< string > &toks, const char *delimiters=" \t") const
Tokenizes the specified string, depositing the results into the specified vector. ...
Definition: tokenizer.H:62
void AddCandidate(shared_ptr< Candidate > candidate)
#define DEBUG
void set_training_key(const string &training_key)
void set_gold_index(size_t index)
float loss
Definition: hadoop-run.py:389
A class to hold a set of candidates, either for training or test.
Definition: candidate-set.H:62
A class to represent a candidate in a set of candidates that constitutes a training instance for a re...
Definition: candidate.H:60
void set_reference_string_token_count(int reference_string_token_count)
Reads CandidateSetMessage instances and converts them to reranker::CandidateSet instances.
V IncrementWeight(const K &uid, V by)
Increments the weight of the specified feature by the specified amount.
void set_best_scoring_index(size_t index)