Reranker Framework (ReFr)
Reranking framework for structure prediction and discriminative language modeling
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
piped-model-evaluator.C
Go to the documentation of this file.
1 // Copyright 2012, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 // -----------------------------------------------------------------------------
30 //
31 //
37 
38 #include <iostream>
39 #include <fstream>
40 #include <string>
41 #include <memory>
42 #include <unistd.h>
43 
44 #include "candidate-set-iterator.H"
45 #include "candidate-set-reader.H"
46 #include "model.H"
47 #include "model-reader.H"
48 
// Compile-time debugging switch (0 disables).
#define DEBUG 0

// Program name used as the prefix of usage and error messages.
#define PROG_NAME "piped-model-evaluator"

// Stringification helpers. Two levels are required so that a macro argument
// is expanded to its value *before* it is turned into a string literal:
// XSTR(DEFAULT_MAX_EXAMPLES) yields "-1", whereas STR(DEFAULT_MAX_EXAMPLES)
// would yield "DEFAULT_MAX_EXAMPLES".
#define STR(arg) #arg
#define XSTR(arg) STR(arg)

// Default values for the command-line options; these are also stringified
// into the usage message via XSTR.
#define DEFAULT_USE_WEIGHTED_LOSS true
#define DEFAULT_REPORTING_INTERVAL 1000
#define DEFAULT_MAX_CANDIDATES -1
#define DEFAULT_MAX_EXAMPLES -1
64 
65 using namespace std;
66 using namespace reranker;
67 
68 const char *usage_msg[] = {
69  "Usage:\n",
70  PROG_NAME " -d|--devtest <devtest input file>+\n",
71  "\t[--dev-config <devtest feature extractor config file>]\n",
72  "\t[--model-files <file with model filenames>\n",
73  "\t[-u] [--no-base64]\n",
74  "\t[--max-examples <max num examples>]\n",
75  "\t[--max-candidates <max num candidates>]\n",
76  "\t[-r <reporting interval>] [ --use-weighted-loss[=][true|false] ]\n",
77  "where\n",
78  "\t<devtest input file> is the name of a stream of serialized\n",
79  "\t\tCandidateSet instances, or \"-\" for input from standard input\n",
80  "\t\t(required unless training in mapper mode)\n",
81  "\t--model-files specifies the name of a file from which to read model\n",
82  "\t\tmodel filenames (use this option for debugging; defaults to stdin)\n",
83  "\t-u specifies that the input files are uncompressed\n",
84  "\t--no-base64 specifies not to use base64 encoding/decoding\n",
85  "\t--max-examples specifies the maximum number of examples to read from\n",
86  "\t\tany input file (defaults to " XSTR(DEFAULT_MAX_EXAMPLES) ")\n",
87  "\t--max-candidates specifies the maximum number of candidates to read\n",
88  "\t\tfor any candidate set (defaults to " XSTR(DEFAULT_MAX_CANDIDATES) ")\n",
89  "\t-r specifies the interval at which the CandidateSetReader reports how\n",
90  "\t\tmany candidate sets it has read (defaults to "
92  "\t--use-weighted-loss specifies whether to weight losses on devtest\n",
93  "\t\texamples by the number of tokens in the reference, where, e.g.,\n",
94  "\t\tweighted loss is appropriate for computing WER, but not BLEU\n",
95  "\t\t(defaults to " XSTR(DEFAULT_USE_WEIGHTED_LOSS) ")\n"
96 };
97 
100 void usage() {
101  int usage_msg_len = sizeof(usage_msg)/sizeof(const char *);
102  for (int i = 0; i < usage_msg_len; ++i) {
103  cout << usage_msg[i];
104  }
105  cout.flush();
106 }
107 
108 bool check_for_required_arg(int argc, int i, string err_msg) {
109  if (i + 1 >= argc) {
110  cerr << PROG_NAME << ": error: " << err_msg << endl;
111  usage();
112  return false;
113  } else {
114  return true;
115  }
116 }
117 
118 int
119 main(int argc, char **argv) {
120  string model_file;
121  bool using_model_filenames_file = false;
122  string model_filenames_file;
123  vector<string> devtest_files;
124  string devtest_feature_extractor_config_file;
125  bool compressed = true;
126  bool use_base64 = true;
127  bool use_weighted_loss = DEFAULT_USE_WEIGHTED_LOSS;
128  string use_weighted_loss_arg_prefix = "--use-weighted-loss";
129  size_t use_weighted_loss_arg_prefix_len =
130  use_weighted_loss_arg_prefix.length();
131  int max_examples = DEFAULT_MAX_EXAMPLES;
132  int max_candidates = DEFAULT_MAX_CANDIDATES;
133  int reporting_interval = DEFAULT_REPORTING_INTERVAL;
134 
135  // Process options. The majority of code in this file is devoted to this.
136  for (int i = 1; i < argc; ++i) {
137  string arg = argv[i];
138  if (arg == "-d" || arg == "-devtest" || arg == "--devtest") {
139  string err_msg = string("no input files specified with ") + arg;
140  if (!check_for_required_arg(argc, i, err_msg)) {
141  return -1;
142  }
143  // Keep reading args until next option or until no more args.
144  ++i;
145  for ( ; i < argc; ++i) {
146  if (argv[i][0] == '-') {
147  --i;
148  break;
149  }
150  devtest_files.push_back(argv[i]);
151  }
152  } else if (arg == "-dev-config" || arg == "--dev-config") {
153  string err_msg =
154  string("no feature extractor config file specified with ") + arg;
155  if (!check_for_required_arg(argc, i, err_msg)) {
156  return -1;
157  }
158  devtest_feature_extractor_config_file = argv[++i];
159  } else if (arg == "-model-files" || arg == "--model-files") {
160  string err_msg = string("no model filenames file specified with ") + arg;
161  if (!check_for_required_arg(argc, i, err_msg)) {
162  return -1;
163  }
164  model_filenames_file = argv[++i];
165  using_model_filenames_file = true;
166  } else if (arg == "-u") {
167  compressed = false;
168  } else if (arg == "--no-base64") {
169  use_base64 = false;
170  } else if (arg == "-max-examples" || arg == "--max-examples") {
171  string err_msg = string("no arg specified with ") + arg;
172  if (!check_for_required_arg(argc, i, err_msg)) {
173  return -1;
174  }
175  max_examples = atoi(argv[++i]);
176  } else if (arg == "-max-candidates" || arg == "--max-candidates") {
177  string err_msg = string("no arg specified with ") + arg;
178  if (!check_for_required_arg(argc, i, err_msg)) {
179  return -1;
180  }
181  max_candidates = atoi(argv[++i]);
182  } else if (arg == "-r") {
183  string err_msg = string("no arg specified with ") + arg;
184  if (!check_for_required_arg(argc, i, err_msg)) {
185  return -1;
186  }
187  reporting_interval = atoi(argv[++i]);
188  } else if (arg.substr(0, use_weighted_loss_arg_prefix_len) ==
189  use_weighted_loss_arg_prefix) {
190  string use_weighted_loss_str;
191  if (arg.length() > use_weighted_loss_arg_prefix_len &&
192  arg[use_weighted_loss_arg_prefix_len] == '=') {
193  use_weighted_loss_str =
194  arg.substr(use_weighted_loss_arg_prefix_len + 1);
195  } else {
196  string err_msg =
197  string("no \"true\" or \"false\" arg specified with ") + arg;
198  if (!check_for_required_arg(argc, i, err_msg)) {
199  return -1;
200  }
201  use_weighted_loss_str = argv[++i];
202  }
203  if (use_weighted_loss_str != "true" &&
204  use_weighted_loss_str != "false") {
205  cerr << PROG_NAME << ": error: must specify \"true\" or \"false\""
206  << " with --use-weighted-loss" << endl;
207  usage();
208  return -1;
209  }
210  if (use_weighted_loss_str != "true") {
211  use_weighted_loss = false;
212  }
213  } else if (arg.size() > 0 && arg[0] == '-') {
214  cerr << PROG_NAME << ": error: unrecognized option: " << arg << endl;
215  usage();
216  return -1;
217  }
218  }
219 
220  if (devtest_files.size() == 0) {
221  cerr << PROG_NAME << ": error: must specify devtest input files when "
222  << "not in mapper mode" << endl;
223  usage();
224  return -1;
225  }
226 
227  shared_ptr<ExecutiveFeatureExtractor> devtest_efe;
228  if (devtest_feature_extractor_config_file != "") {
229  devtest_efe = ExecutiveFeatureExtractor::InitFromSpec(
230  devtest_feature_extractor_config_file);
231  }
232 
233  CandidateSetReader csr(max_examples, max_candidates, reporting_interval);
234  csr.set_verbosity(1);
235  bool reset_counters = true;
236 
237  cerr << "Reading devtest examples." << endl;
238 
239  vector<shared_ptr<CandidateSet> > devtest_examples;
240  for (vector<string>::const_iterator file_it = devtest_files.begin();
241  file_it != devtest_files.end();
242  ++file_it) {
243  csr.Read(*file_it, compressed, use_base64, reset_counters,
244  devtest_examples);
245  }
246  // Extract features for CandidateSet instances in situ.
247  for (vector<shared_ptr<CandidateSet> >::iterator it =
248  devtest_examples.begin();
249  it != devtest_examples.end();
250  ++it) {
251  devtest_efe->Extract(*(*it));
252  }
253 
254  cerr << "Done reading devtest examples." << endl;
255 
256  if (devtest_examples.size() == 0) {
257  cerr << "Could not read any devtest examples. Exiting." << endl;
258  return -1;
259  }
260 
262  CandidateSetVectorIt;
263 
264  istream *model_filenames_stream =
265  using_model_filenames_file ?
266  new ifstream(model_filenames_file.c_str()) : &cin;
267 
268  ModelReader model_reader(1);
269  while (getline(*model_filenames_stream, model_file)) {
270  cerr << "Evaluating model \"" << model_file << "\"." << endl;
271  shared_ptr<Model> model =
272  model_reader.Read(model_file, compressed, use_base64);
273  model->set_use_weighted_loss(use_weighted_loss);
274  CandidateSetVectorIt devtest_examples_it(devtest_examples);
275  model->NewEpoch(); // sets epoch to 0
276  cout << model->Evaluate(devtest_examples_it) << endl;
277 
278  // Decompile all features in devtest examples (will do nothing if there
279  // were no symbolic features to begin with; see "dont_force" below).
280  devtest_examples_it.Reset();
281  while (devtest_examples_it.HasNext()) {
282  CandidateSet &candidate_set = devtest_examples_it.Next();
283  bool dont_force = false;
284  candidate_set.DecompileFeatures(model->symbols(), true, true, dont_force);
285  }
286  }
287  if (using_model_filenames_file) {
288  delete model_filenames_stream;
289  }
290 }
#define DEFAULT_REPORTING_INTERVAL
Provides an interface and some implementations for iterating over CandidateSet instances.
Provides the ModelReader class, which can create Model instances from a file.
#define DEFAULT_MAX_EXAMPLES
void set_verbosity(int verbosity)
Sets the verbosity of this reader (mostly for debugging purposes).
const char * usage_msg[]
void DecompileFeatures(Symbols *symbols, bool clear_symbolic_features=false, bool clear_features=true, bool force=false)
Decompiles any non-symbolic features in the candidates in this candidate set.
void usage()
Class for reading streams of training or test instances, where each training or test instance is a re...
A class to hold a set of candidates, either for training or test.
Definition: candidate-set.H:62
void Read(const string &filename, bool compressed, bool use_base64, bool reset_counters, vector< shared_ptr< CandidateSet > > &examples)
Reads a stream of CandidateSet instances from the specified file or from standard input...
int main(int argc, char **argv)
#define PROG_NAME
An implementation of the CandidateSetIterator interface that is backed by an arbitrary C++ collection...
#define DEFAULT_MAX_CANDIDATES
#define XSTR(arg)
Expands the string value of the specified argument using the STR macro.
#define DEFAULT_USE_WEIGHTED_LOSS
A class for reading streams of training or test instances, where each training or test instance is a ...
bool check_for_required_arg(int argc, int i, string err_msg)
Knows how to create Model instances that have been serialized to a file.
Definition: model-reader.H:55
Reranker model interface.