40 #include "../proto/dataio.h"
41 #include "../utils/kdebug.h"
42 #include "../proto/model.pb.h"
48 using confusion_learning::FeatureMessage;
49 using confusion_learning::ModelMessage;
54 : num_merged_(0), uniform_mix_(uniform_mix),
55 mix_denominator_(mix_denominator) {
60 FeatureMessage new_message;
61 if (!messageio_.DecodeBase64(value, &new_message)) {
62 cerr <<
"Error decoding message: " << value.c_str() << endl;
65 if (feat_id.compare(prev_feat_) != 0) {
67 if (!prev_feat_.empty()) {
70 double normalizer = uniform_mix_ ?
static_cast<double>(num_merged_) :
72 if (normalizer != 1.0) {
73 cur_message_.set_value(cur_message_.value() / normalizer);
74 cur_message_.set_avg_value(cur_message_.avg_value() / normalizer);
78 messageio_.EncodeBase64(cur_message_, &encoded_msg);
79 cout << prev_feat_.c_str() <<
"\t" << encoded_msg.c_str();
83 cur_message_.CopyFrom(new_message);
87 cur_message_.set_value(cur_message_.value() + new_message.value());
88 cur_message_.set_avg_value(cur_message_.avg_value() + new_message.avg_value());
89 cur_message_.set_count(cur_message_.count() + new_message.count());
97 if (!prev_feat_.empty()) {
98 double normalizer = uniform_mix_ ?
static_cast<double>(num_merged_) :
100 if (normalizer != 1.0) {
101 cur_message_.set_value(cur_message_.value() / normalizer);
102 cur_message_.set_avg_value(cur_message_.avg_value() / normalizer);
105 messageio_.EncodeBase64(cur_message_, &encoded_msg);
106 cout << prev_feat_.c_str() <<
"\t" << encoded_msg.c_str();
108 cur_message_.Clear();
116 ModelMessage new_message;
117 if (!messageio_.DecodeBase64(value, &new_message)) {
118 cerr <<
"Error decoding message: " << value.c_str() << endl;
120 if (new_model_message_) {
121 model_message_.CopyFrom(new_message);
122 new_model_message_ =
false;
123 if (model_message_.has_symbols()) {
124 model_message_.clear_symbols();
127 model_message_.set_loss(model_message_.loss() + new_message.loss());
128 model_message_.set_training_errors(
129 model_message_.training_errors() + new_message.training_errors());
130 if (model_message_.reader_spec().compare(new_message.reader_spec()) != 0) {
131 cerr <<
"Combining messages with different reader_spec fields.";
135 if (model_message_.model_spec().compare(new_message.model_spec()) != 0) {
136 cerr <<
"Combining messages with different model_spec fields.";
139 if (model_message_.identifier().compare(new_message.identifier()) != 0) {
140 cerr <<
"Combining messages with different identifier fields.";
143 if (model_message_.num_iterations() != new_message.num_iterations()) {
144 cerr <<
"Combining messages with different num_iterations fields.";
147 if (model_message_.has_symbols()) {
148 if (new_message.has_symbols()) {
159 if (new_model_message_) {
162 if (!model_message_.has_num_iterations()) {
163 cerr <<
"No model information";
166 model_message_.set_num_iterations(model_message_.num_iterations() + 1);
168 messageio_.EncodeBase64(model_message_, &encoded_msg);
170 cout <<
"\t" << encoded_msg.c_str();
171 model_message_.Clear();
176 if (key.compare(prev_sym_) != 0) {
177 cout << key.c_str() << endl;
virtual int Reduce(const string &key, const string &value)
FeatureReducer(bool uniform_mix, double mix_denominator)
Reducer classes for trainer.
virtual int Reduce(const string &feat_id, const string &value)
static const char * kModelMessageFeatureName
virtual int Reduce(const string &key, const string &value)