41 #include "../proto/model.pb.h"
49 using confusion_learning::FeatureMessage;
50 using confusion_learning::SymbolTableMessage;
51 using confusion_learning::SymbolMessage;
57 perceptron_model->
name_ = model_message.identifier();
62 if (perceptron_model->
symbols_ != NULL && model_message.has_symbols()) {
63 const SymbolTableMessage &symbol_table_message = model_message.symbols();
64 for (
int i = 0; i < symbol_table_message.symbol_size(); ++i) {
65 const SymbolMessage &symbol_message = symbol_table_message.symbol(i);
67 symbol_message.index());
71 if (model_message.has_raw_parameters()) {
72 fv_reader_.Read(model_message.raw_parameters(),
76 if (model_message.has_avg_parameters()) {
77 fv_reader_.Read(model_message.avg_parameters(),
101 const string& separator)
const {
105 ConfusionProtoIO proto_reader;
107 while (is && is.good()) {
109 if (buffer.empty()) {
113 size_t seppos = buffer.find(separator);
114 if (seppos != string::npos) {
115 buffer.erase(0, seppos+1);
118 FeatureMessage feature_msg;
119 if (!proto_reader.DecodeBase64(buffer, &feature_msg)) {
120 cerr <<
"Error decoding: " << feature_msg.Utf8DebugString() << endl;
123 int uid = feature_msg.id();
124 if (symbols != NULL &&
125 feature_msg.has_name() && !feature_msg.name().empty()) {
126 uid = symbols->
GetIndex(feature_msg.name());
128 double value = feature_msg.value();
129 if (std::isnan(value)) {
130 cerr <<
"PerceptronModelProtoReader: WARNING: feature "
131 << uid <<
" has value that is NaN" << endl;
135 if (feature_msg.has_avg_value()) {
136 double avg_value = feature_msg.avg_value();
137 if (std::isnan(avg_value)) {
138 cerr <<
"PerceptronModelProtoReader: WARNING: feature "
139 << uid <<
" has avg_value that is NaN" << endl;
147 if (features.weights_.
size() == 0 && features.average_weights_.
size() > 0) {
148 features.weights_ = features.average_weights_;
149 }
else if (features.average_weights_.
size() == 0 &&
150 features.weights_.
size() > 0) {
151 features.average_weights_ = features.weights_;
155 perceptron_model->
models_ = features;
size_t size() const
Returns the number of non-zero feature components of this feature vector.
Model is an interface for reranking models.
#define REGISTER_MODEL_PROTO_READER(TYPE)
Registers the ModelProtoReader implementation with the specified subtype TYPE with the ModelProtoReader...
TrainingVectorSet best_models_
The best models seen so far during training, according to evaluation on the held-out development test...
virtual void ReadFeatures(istream &is, Model *model, bool skip_key, const string &separator) const
De-serializes Features from an instance.
A simple class to hold the three notions of time during training: the current epoch, the current time index within the current epoch, and the absolute time index.
Symbols * symbols() const
Returns the symbol table for this model.
De-serializer for reranker::PerceptronModel instances from ModelMessage instances.
TrainingVectorSet models_
The feature vectors representing this model.
This class implements a perceptron model reranker.
Symbols * symbols_
The symbol table for this model (may be NULL).
A class to construct a PerceptronModel from a ModelMessage instance.
Time time_
The tiny object that holds the "training time" for this model (epoch, index and absolute time index)...
virtual int GetIndex(const string &symbol)=0
Converts the specified symbol to a unique integer.
An interface specifying a converter from symbols (strings) to int indices.
A class to hold the several feature vectors needed during training (especially for the perceptron family...
string name_
This model’s unique name.
V IncrementWeight(const K &uid, V by)
Increments the weight of the specified feature by the specified amount.
int best_model_epoch_
The epoch of the best models seen so far during training.
virtual void SetIndex(const string &symbol, int index)=0
Provides the reranker::TrainingVectorSet class.