Reranker Framework (ReFr)
Reranking framework for structure prediction and discriminative language modeling
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
model-merge-reducer.C
Go to the documentation of this file.
1 // Copyright 2012, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 // -----------------------------------------------------------------------------
30 //
31 //
36 
37 #include <iostream>
38 #include <string>
39 #include <stdlib.h>
40 #include "../proto/dataio.h"
41 #include "../utils/kdebug.h"
42 #include "../proto/model.pb.h"
43 #include "model-merge-reducer.H"
44 
45 namespace reranker {
46 
47 using namespace std;
48 using confusion_learning::FeatureMessage;
49 using confusion_learning::ModelMessage;
50 
// Reserved key string "__MODEL_INFO_FIELD__" associated with ModelInfoReducer.
// NOTE(review): the statement that emits this key is not visible in this
// listing; presumably it labels the merged ModelMessage record in the reducer
// output so it can be told apart from per-feature records — confirm.
51 const char* ModelInfoReducer::kModelMessageFeatureName = "__MODEL_INFO_FIELD__";
52 
53 FeatureReducer::FeatureReducer(bool uniform_mix, double mix_denominator)
54  : num_merged_(0), uniform_mix_(uniform_mix),
55  mix_denominator_(mix_denominator) {
56 }
57 
58 int FeatureReducer::Reduce(const string& feat_id, const string& value) {
59  // Decode the value as a FeatureMessage protocol buffer.
60  FeatureMessage new_message;
61  if (!messageio_.DecodeBase64(value, &new_message)) {
62  cerr << "Error decoding message: " << value.c_str() << endl;
63  }
64  int num_output = 0;
65  if (feat_id.compare(prev_feat_) != 0) {
66  // If this is a new feature (new key), then output the previous features.
67  if (!prev_feat_.empty()) {
68  // Do we want to normalize the mixture with the number of mappers which
69  // found this feature ?
70  double normalizer = uniform_mix_ ? static_cast<double>(num_merged_) :
71  mix_denominator_;
72  if (normalizer != 1.0) {
73  cur_message_.set_value(cur_message_.value() / normalizer);
74  cur_message_.set_avg_value(cur_message_.avg_value() / normalizer);
75  }
76  // Encode message and output to stdout.
77  string encoded_msg;
78  messageio_.EncodeBase64(cur_message_, &encoded_msg);
79  cout << prev_feat_.c_str() << "\t" << encoded_msg.c_str();
80  }
81  // Record the new key and clear the state.
82  prev_feat_ = feat_id;
83  cur_message_.CopyFrom(new_message);
84  num_merged_ = 1;
85  num_output = 1;
86  } else {
87  cur_message_.set_value(cur_message_.value() + new_message.value());
88  cur_message_.set_avg_value(cur_message_.avg_value() + new_message.avg_value());
89  cur_message_.set_count(cur_message_.count() + new_message.count());
90  num_merged_++;
91  }
92  // Update state.
93  return num_output;
94 }
95 
97  if (!prev_feat_.empty()) {
98  double normalizer = uniform_mix_ ? static_cast<double>(num_merged_) :
99  mix_denominator_;
100  if (normalizer != 1.0) {
101  cur_message_.set_value(cur_message_.value() / normalizer);
102  cur_message_.set_avg_value(cur_message_.avg_value() / normalizer);
103  }
104  string encoded_msg;
105  messageio_.EncodeBase64(cur_message_, &encoded_msg);
106  cout << prev_feat_.c_str() << "\t" << encoded_msg.c_str();
107  prev_feat_.clear();
108  cur_message_.Clear();
109  num_merged_ = 0;
110  return 1;
111  }
112  return 0;
113 }
114 
115 int ModelInfoReducer::Reduce(const string& key, const string& value) {
116  ModelMessage new_message;
117  if (!messageio_.DecodeBase64(value, &new_message)) {
118  cerr << "Error decoding message: " << value.c_str() << endl;
119  }
120  if (new_model_message_) {
121  model_message_.CopyFrom(new_message);
122  new_model_message_ = false;
123  if (model_message_.has_symbols()) {
124  model_message_.clear_symbols();
125  }
126  } else {
127  model_message_.set_loss(model_message_.loss() + new_message.loss());
128  model_message_.set_training_errors(
129  model_message_.training_errors() + new_message.training_errors());
130  if (model_message_.reader_spec().compare(new_message.reader_spec()) != 0) {
131  cerr << "Combining messages with different reader_spec fields.";
132  return -1;
133  }
134  // Check that the models being merged have the same specs.
135  if (model_message_.model_spec().compare(new_message.model_spec()) != 0) {
136  cerr << "Combining messages with different model_spec fields.";
137  return -1;
138  }
139  if (model_message_.identifier().compare(new_message.identifier()) != 0) {
140  cerr << "Combining messages with different identifier fields.";
141  return -1;
142  }
143  if (model_message_.num_iterations() != new_message.num_iterations()) {
144  cerr << "Combining messages with different num_iterations fields.";
145  return -1;
146  }
147  if (model_message_.has_symbols()) {
148  if (new_message.has_symbols()) {
149  // Do something sensible to merge symbols ???
150  // Or assume they are the same.
151  // TODO(kbhall): resolve this symbols problem.
152  }
153  }
154  }
155  return 0;
156 }
157 
159  if (new_model_message_) {
160  return 0;
161  }
162  if (!model_message_.has_num_iterations()) {
163  cerr << "No model information";
164  return -1;
165  }
166  model_message_.set_num_iterations(model_message_.num_iterations() + 1);
167  string encoded_msg;
168  messageio_.EncodeBase64(model_message_, &encoded_msg);
170  cout << "\t" << encoded_msg.c_str();
171  model_message_.Clear();
172  return 1;
173 }
174 
175 int SymbolReducer::Reduce(const string& key, const string& value) {
176  if (key.compare(prev_sym_) != 0) {
177  cout << key.c_str() << endl;
178  prev_sym_ = key;
179  return 1;
180  }
181  return 0;
182 }
183 
184 } // End of namespace.
virtual int Reduce(const string &key, const string &value)
FeatureReducer(bool uniform_mix, double mix_denominator)
Reducer classes for trainer.
virtual int Reduce(const string &feat_id, const string &value)
static const char * kModelMessageFeatureName
virtual int Reduce(const string &key, const string &value)