Reranker Framework (ReFr)
Reranking framework for structure prediction and discriminative language modeling
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
model.C
Go to the documentation of this file.
1 // Copyright 2012, Google Inc.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Google Inc. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 // -----------------------------------------------------------------------------
30 //
31 //
37 
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <sstream>
#include <stdexcept>
42 
43 #include "tokenizer.H"
44 #include "model.H"
45 
46 using namespace std;
47 
48 namespace reranker {
49 
// Provide the factory implementation (via the IMPLEMENT_FACTORY macro from
// factory.H) for each base type that this file's Factory<...> users rely on:
// Model itself plus its nested UpdatePredicate and Updater interfaces.
50 IMPLEMENT_FACTORY(Model)
51 IMPLEMENT_FACTORY(Model::UpdatePredicate)
52 IMPLEMENT_FACTORY(Model::Updater)
53 
56 
59 
// Scores every candidate in the set with model->ScoreCandidate and records,
// on the candidate set, the index of the best-scoring candidate (according
// to the model's score comparator) and of the gold candidate (according to
// the model's gold comparator).
//
// NOTE(review): the line carrying the qualified function name and the first
// parameter(s) is absent from this listing (the embedded numbering jumps
// from 60 to 62). The body uses a pointer named `model`, so it is presumably
// declared on the missing line -- restore it from the original file.
//
// The set must be non-empty: candidates.begin() is dereferenced
// unconditionally below.
60 void
62  CandidateSet &candidates, bool training) {
63  // N.B.: We assume there is at least one candidate!
64  CandidateSet::iterator it = candidates.begin();
65 
66  // Score first candidate, which is perforce the best candidate so far.
67  model->ScoreCandidate(*(*it), training);
68 
// Both running winners start out as the first candidate.
69  CandidateSet::iterator best_it = it;
70  CandidateSet::iterator gold_it = it;
71  ++it;
72 
73  // Score any and all remaining candidates.
74  for ( ; it != candidates.end(); ++it) {
75  Candidate &candidate = *(*it);
76  model->ScoreCandidate(candidate, training);
// A candidate replaces the current winner only when Compare returns a value
// strictly greater than zero, so on ties the earlier candidate is kept.
77  if (model->score_comparator()->Compare(*model, candidate, **best_it) > 0) {
78  best_it = it;
79  }
80  if (model->gold_comparator()->Compare(*model, candidate, **gold_it) > 0) {
81  gold_it = it;
82  }
83  }
// Publish the winners by their own index() values rather than by position.
84  candidates.set_best_scoring_index((*best_it)->index());
85  candidates.set_gold_index((*gold_it)->index());
86 }
87 
88 void
89 RandomPairCandidateSetScorer::Init(const Environment *env,
90  const string &arg) {
91  srand(time(NULL));
92 }
93 
94 size_t
95 RandomPairCandidateSetScorer::GetRandomIndex(size_t max) {
96  // Get an index proportional to the reciprocal rank.
97  // First, get a floating-point value that is distributed uniformly (roughly)
99  double r = rand() / (double)RAND_MAX;
100  // We compute the cummulative density function, or cdf, of the
101  // reciprocal rank distribution over the fixed set of items. As soon as
102  // our uniformly-distributed r is less than the cdf at index i, we
103  // return that index.
104  double cdf = 0.0;
105  double denominator = (max * (max + 1)) / 2;
106  for (size_t i = 0; i < max; ++i) {
107  cdf += (max - i) / denominator;
108  if (r <= cdf) {
109  return i;
110  }
111  }
112  return max - 1;
113 }
114 
115 void
116 RandomPairCandidateSetScorer::Score(Model *model,
117  CandidateSet &candidates, bool training) {
118  // First, pick two candidate indices at random.
119  size_t idx1 = GetRandomIndex(candidates.size());
120  size_t idx2 = GetRandomIndex(candidates.size());
121  Candidate &c1 = candidates.Get(idx1);
122  Candidate &c2 = candidates.Get(idx2);
123 
124  // Next, just score those two candidates.
125  model->ScoreCandidate(c1, training);
126  model->ScoreCandidate(c2, training);
127 
128  // Finally, set indices of best scoring and gold amongst just those two.
129  int score_cmp = model->score_comparator()->Compare(*model, c1, c2);
130  candidates.set_best_scoring_index(score_cmp > 0 ? c1.index() : c2.index());
131 
132  int gold_cmp = model->gold_comparator()->Compare(*model, c1, c2);
133  candidates.set_gold_index(gold_cmp > 0 ? c1.index() : c2.index());
134 }
135 
136 void
137 Model::CheckNumberOfTokens(const string &arg,
138  const vector<string> &tokens,
139  size_t min_expected_number,
140  size_t max_expected_number,
141  const string &class_name) const {
142  if ((min_expected_number > 0 && tokens.size() < min_expected_number) ||
143  (max_expected_number > 0 && tokens.size() > max_expected_number)) {
144  std::stringstream err_ss;
145  err_ss << class_name << "::Init: error parsing init string \""
146  << arg << "\": expected between "
147  << min_expected_number << " and " << max_expected_number
148  << " tokens but found " << tokens.size() << " tokens";
149  cerr << err_ss.str() << endl;
150  throw std::runtime_error(err_ss.str());
151  }
152 }
153 
154 shared_ptr<Candidate::Comparator>
155 Model::GetComparator(const string &spec) const {
156  Factory<Candidate::Comparator> comparator_factory;
157  string err_msg = "error: model " + name() + ": could not construct " +
158  "Candidate::Comparator from specification string \"" + spec + "\"";
159  return comparator_factory.CreateOrDie(spec, err_msg);
160 }
161 
162 shared_ptr<CandidateSet::Scorer>
163 Model::GetCandidateSetScorer(const string &spec) const {
164  Factory<CandidateSet::Scorer> candidate_set_scorer_factory;
165  string err_msg = "error: model " + name() + ": could not construct " +
166  "Candidate::Scorer from specification string \"" + spec + "\"";
167  return candidate_set_scorer_factory.CreateOrDie(spec, err_msg);
168 }
169 
170 shared_ptr<Model::UpdatePredicate>
171 Model::GetUpdatePredicate(const string &spec) const {
172  Factory<Model::UpdatePredicate> update_predicate_factory;
173  string err_msg = "error: model " + name() + ": could not construct " +
174  "Model::UpdatePredicate from specification string \"" + spec + "\"";
175  return update_predicate_factory.CreateOrDie(spec, err_msg);
176 }
177 
178 shared_ptr<Model::Updater>
179 Model::GetUpdater(const string &spec) const {
180  Factory<Model::Updater> updater_factory;
181  string err_msg = "error: model " + name() + ": could not construct " +
182  "Model::Updater from specification string \"" + spec + "\"";
183  return updater_factory.CreateOrDie(spec, err_msg);
184 }
185 
186 } // namespace reranker
Model is an interface for reranking models.
Definition: model.H:141
#define REGISTER_CANDIDATE_COMPARATOR(TYPE)
Definition: candidate.H:284
The default comparator for comparing two Candidate instances based on their respective scores (i...
Definition: model.H:60
vector< shared_ptr< Candidate > >::iterator iterator
Definition: candidate-set.H:73
Provides the Tokenizer class.
The default comparator for comparing two Candidate instances for being the “gold” candidate...
Definition: model.H:75
virtual shared_ptr< Candidate::Comparator > score_comparator()
Returns a pointer to the score comparator used by this model.
Definition: model.H:334
Factory for dynamically created instance of the specified type.
Definition: factory.H:396
int index() const
Returns the index of this candidate relative to the other candidates.
Definition: candidate.H:121
This candidate set scorer picks two candidates at random from the set, scores them and then identifie...
Definition: model.H:119
size_t size() const
Definition: candidate-set.H:94
void set_gold_index(size_t index)
A class to hold a set of candidates, either for training or test.
Definition: candidate-set.H:62
shared_ptr< T > CreateOrDie(StreamTokenizer &st, Environment *env=NULL)
Dynamically creates an object, whose type and initialization are contained in a specification string...
Definition: factory.H:562
An interface for an environment in which variables of various types are mapped to their values...
Definition: environment.H:125
A class to represent a candidate in a set of candidates that constitutes a training instance for a re...
Definition: candidate.H:60
virtual shared_ptr< Candidate::Comparator > gold_comparator()
Returns a pointer to the gold comparator used by this model.
Definition: model.H:339
#define IMPLEMENT_FACTORY(BASE)
Provides the necessary implementation for a factory for the specified BASE class type.
Definition: factory.H:821
#define REGISTER_CANDIDATE_SET_SCORER(TYPE)
Candidate & Get(size_t idx)
The default candidate set scorer scores each candidate using the Model::ScoreCandidate method and the...
Definition: model.H:109
void set_best_scoring_index(size_t index)
Reranker model interface.
virtual double ScoreCandidate(Candidate &candidate, bool training)=0
Scores a candidate according to either the raw or averaged version of this perceptron model...