Reranker Framework (ReFr)
Reranking framework for structure prediction and discriminative language modeling
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
perceptron-model-proto-reader.C
Go to the documentation of this file.
1 // Copyright 2012, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 // -----------------------------------------------------------------------------
30 //
31 //
36 
37 #include <cmath>
38 #include <cstdio>
39 #include <iostream>
40 #include <stdlib.h>
41 #include "../proto/model.pb.h"
42 #include "training-vector-set.H"
44 
45 namespace reranker {
46 
48 
49 using confusion_learning::FeatureMessage;
50 using confusion_learning::SymbolTableMessage;
51 using confusion_learning::SymbolMessage;
52 
53 void
54 PerceptronModelProtoReader::Read(const ModelMessage &model_message,
55  Model *model) const {
56  PerceptronModel *perceptron_model = static_cast<PerceptronModel *>(model);
57  perceptron_model->name_ = model_message.identifier();
58  perceptron_model->best_model_epoch_ = model_message.num_iterations();
59  perceptron_model->time_ = Time(perceptron_model->best_model_epoch_, -1, -1);
60  // TODO(dbikel): Emit warning if model_message.has_symbols() returns true
61  // when perceptron_model->symbols_ is NULL?
62  if (perceptron_model->symbols_ != NULL && model_message.has_symbols()) {
63  const SymbolTableMessage &symbol_table_message = model_message.symbols();
64  for (int i = 0; i < symbol_table_message.symbol_size(); ++i) {
65  const SymbolMessage &symbol_message = symbol_table_message.symbol(i);
66  perceptron_model->symbols_->SetIndex(symbol_message.symbol(),
67  symbol_message.index());
68  }
69  }
70  // TODO(dbikel): De-serialize model loss.
71  if (model_message.has_raw_parameters()) {
72  fv_reader_.Read(model_message.raw_parameters(),
73  perceptron_model->best_models_.weights_,
74  perceptron_model->symbols());
75  }
76  if (model_message.has_avg_parameters()) {
77  fv_reader_.Read(model_message.avg_parameters(),
78  perceptron_model->best_models_.average_weights_,
79  perceptron_model->symbols());
80  }
81  // Do "smart copying".
82  if (smart_copy_) {
83  if (perceptron_model->best_models_.weights_.size() == 0 &&
84  perceptron_model->best_models_.average_weights_.size() > 0) {
85  perceptron_model->best_models_.weights_ =
86  perceptron_model->best_models_.average_weights_;
87  } else if (perceptron_model->best_models_.average_weights_.size() == 0 &&
88  perceptron_model->best_models_.weights_.size() > 0) {
89  perceptron_model->best_models_.average_weights_ =
90  perceptron_model->best_models_.weights_;
91  }
92  }
93 
94  // Finally, make sure best_models_ is copied to models_.
95  perceptron_model->models_ = perceptron_model->best_models_;
96 }
97 
99  Model *model,
100  bool skip_key,
101  const string& separator) const {
102  PerceptronModel *perceptron_model = dynamic_cast<PerceptronModel *>(model);
103  TrainingVectorSet &features = perceptron_model->best_models_;
104  Symbols *symbols = perceptron_model->symbols();
105  ConfusionProtoIO proto_reader;
106  string buffer;
107  while (is && is.good()) {
108  getline(is, buffer);
109  if (buffer.empty()) {
110  break;
111  }
112  if (skip_key) {
113  size_t seppos = buffer.find(separator);
114  if (seppos != string::npos) {
115  buffer.erase(0, seppos+1);
116  }
117  }
118  FeatureMessage feature_msg;
119  if (!proto_reader.DecodeBase64(buffer, &feature_msg)) {
120  cerr << "Error decoding: " << feature_msg.Utf8DebugString() << endl;
121  continue;
122  }
123  int uid = feature_msg.id();
124  if (symbols != NULL &&
125  feature_msg.has_name() && !feature_msg.name().empty()) {
126  uid = symbols->GetIndex(feature_msg.name());
127  }
128  double value = feature_msg.value();
129  if (std::isnan(value)) {
130  cerr << "PerceptronModelProtoReader: WARNING: feature "
131  << uid << " has value that is NaN" << endl;
132  } else {
133  features.weights_.IncrementWeight(uid, value);
134  }
135  if (feature_msg.has_avg_value()) {
136  double avg_value = feature_msg.avg_value();
137  if (std::isnan(avg_value)) {
138  cerr << "PerceptronModelProtoReader: WARNING: feature "
139  << uid << " has avg_value that is NaN" << endl;
140  } else {
141  features.average_weights_.IncrementWeight(uid, avg_value);
142  }
143  }
144  }
145  // Do "smart copying".
146  if (smart_copy_) {
147  if (features.weights_.size() == 0 && features.average_weights_.size() > 0) {
148  features.weights_ = features.average_weights_;
149  } else if (features.average_weights_.size() == 0 &&
150  features.weights_.size() > 0) {
151  features.average_weights_ = features.weights_;
152  }
153  }
154  // Make sure to copy latest model to models_.
155  perceptron_model->models_ = features;
156 }
157 
158 } // namespace reranker
size_t size() const
Returns the number of non-zero feature components of this feature vector.
Model is an interface for reranking models.
Definition: model.H:141
#define REGISTER_MODEL_PROTO_READER(TYPE)
Registers the ModelProtoReader implementation with the specified subtype TYPE with the ModelProtoRea...
TrainingVectorSet best_models_
The best models seen so far during training, according to evaluation on the held-out development test...
virtual void ReadFeatures(istream &is, Model *model, bool skip_key, const string &separator) const
De-serializes Features from an instance.
A simple class to hold the three notions of time during training: the current epoch, the current time index within the current epoch, and the absolute time index.
Definition: training-time.H:56
Symbols * symbols() const
Returns the symbol table for this model.
Definition: model.H:284
De-serializer for reranker::PerceptronModel instances from ModelMessage instances.
TrainingVectorSet models_
The feature vectors representing this model.
This class implements a perceptron model reranker.
Symbols * symbols_
The symbol table for this model (may be NULL).
Definition: model.H:557
A class to construct a PerceptronModel from a ModelMessage instance.
Time time_
The tiny object that holds the "training time" for this model (epoch, index and absolute time index)...
Definition: model.H:552
virtual int GetIndex(const string &symbol)=0
Converts the specified symbol to a unique integer.
An interface specifying a converter from symbols (strings) to int indices.
Definition: symbol-table.H:57
A class to hold the several feature vectors needed during training (especially for the perceptron fam...
string name_
This model’s unique name.
Definition: model.H:547
V IncrementWeight(const K &uid, V by)
Increments the weight of the specified feature by the specified amount.
int best_model_epoch_
The epoch of the best models seen so far during training.
virtual void SetIndex(const string &symbol, int index)=0
Provides the reranker::TrainingVectorSet class.