/*******************************************************************************
    Copyright (c) 2011, Yahoo! Inc.
    All rights reserved.

    Redistribution and use of this software in source and binary forms,
    with or without modification, are permitted provided that the following
    conditions are met:

    * Redistributions of source code must retain the above
      copyright notice, this list of conditions and the
      following disclaimer.

    * Redistributions in binary form must reproduce the above
      copyright notice, this list of conditions and the
      following disclaimer in the documentation and/or other
      materials provided with the distribution.

    * Neither the name of Yahoo! Inc. nor the names of its
      contributors may be used to endorse or promote products
      derived from this software without specific prior
      written permission of Yahoo! Inc.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
    IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
    TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

    The Initial Developer of the Original Code is Shravan Narayanamurthy.
*******************************************************************************/
/*
 * Unigram_Model_Trainer.h
 *
 *  Created on: 06-Jan-2011
 */

#ifndef UNIGRAM_MODEL_TRAINER_H_
#define UNIGRAM_MODEL_TRAINER_H_

#include "TopicLearner/Model_Refiner.h"
#include "TypeTopicCounts.h"
#include "TopicLearner/Parameter.h"
#include "DocumentReader.h"
#include "DocumentWriter.h"
#include <boost/random/variate_generator.hpp>
#include <boost/random/uniform_real.hpp>

using namespace boost;
using namespace std;

/**
 * The default implementation of Model_Refiner for the
 * Unigram model
 */
class Unigram_Model_Trainer: public Model_Refiner {
public:
    Unigram_Model_Trainer(TypeTopicCounts&, Parameter&, Parameter&);
    virtual ~Unigram_Model_Trainer();

    google::protobuf::Message* allocate_document_buffer(size_t);
    void deallocate_document_buffer(google::protobuf::Message*);
    google::protobuf::Message* get_nth_document(
            google::protobuf::Message* docs, size_t n);

    //!Reads a document from the protobuf-format
    //!word & topic files using DocumentReader
    void* read(google::protobuf::Message&);

    //!Does Gibbs sampling (implemented in sampler.cpp)
    //!to figure out new topic assignments for each
    //!word in the document passed in via the msg
    void* sample(void*);
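    /*
     * For reference, a sketch of the collapsed Gibbs update that such a
     * sampler typically performs (standard LDA notation, not names from
     * this codebase): a word w in document d is reassigned topic t with
     * probability proportional to
     *
     *     (alpha_t + n(d,t)) * (beta + n(w,t)) / (V*beta + n(t))
     *
     * where n(d,t) is the number of words in d assigned to t, n(w,t) the
     * number of occurrences of w assigned to t corpus-wide, n(t) the total
     * count for topic t, and V the vocabulary size. The Abar/Bbar/
     * C_cac_coeff arguments of do_one_doc below suggest the three-bucket
     * (smoothing / document / topic-word) decomposition of this mass used
     * by SparseLDA-style samplers, but that is an inference from the names.
     */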
    //!Takes a msg which contains the document to be
    //!processed and the updated topics for each word
    //!in the document as a vector. It then processes
    //!each update by calling upd_count on the
    //!TypeTopicCounts object with the update details
    void* update(void*);

    //!Performs stochastic gradient descent to optimize
    //!the alphas. The gradients are accumulated over tau
    //!docs and then merged into the global alphas.
    void* optimize(void*);

    //!Computes the document portion of the log-likelihood
    void* eval(void*, double&);

    //!Takes the document and writes it to disk. As a
    //!simple optimization, only the topics are written,
    //!not the body/words of the document: the words in
    //!the document never change, it's only the topics
    //!that do. The documents are written to disk using
    //!a DocumentWriter
    void write(void*);

    void iteration_done();

    void* test(void*);

    static long doc_index; //!Running count of all the documents processed by the optimizer

private:
    void set_up_io(string, string, string);
    void release_io();

    //!Sampler
    void do_one_doc(int num_words_in_doc, LDA::unigram_document*& doc,
            topicCounts& current_topic_counts,
            atomic<topic_t>*& tokens_per_topic,
            topic_t*& document_topic_counts, topic_t*& document_topic_index,
            int& non_zero_topics, double& Abar, double& Bbar,
            double*& C_cac_coeff, double*& topic_term_scores,
            vector<change_elem_t>*& updates);

    void sample_topics(update_t* upd, vector<change_elem_t>* updates);

private:
    TypeTopicCounts& _ttc;
    Parameter& _alpha;
    Parameter& _beta;
    bool ignore_old_topic;
    int _num_words, _num_topics;

    //!Reader
    DocumentReader *_wdoc_rdr, *_tdoc_rdr;
    DocumentWriter *_tdoc_writer;

    //!Sampler
    //!The structures needed to set up the boost RNG:
    //!a combination of a variate_generator and a
    //!distribution object. We create an array of RNGs
    //!since it's hard to maintain thread-local state
    //!in TBB. We choose an index into the array at
    //!random and hope that different threads land on
    //!different RNGs
    base_generator_type *generators[NUM_RNGS];
    uniform_real<> *uni_dists[NUM_RNGS];
    variate_generator<base_generator_type&, boost::uniform_real<> >
            *unif01[NUM_RNGS];
    long rng_ind;

    //!Optimizer
    double *part_grads, //!Local alphas into which we accumulate the gradients
            part_grads_top_indep; //!Local AlphaBar
    int tau; //!The number of documents over which gradients are accumulated
    double eta; //!The fraction of the gradient merged into the global alphas
};

#endif /* UNIGRAM_MODEL_TRAINER_H_ */
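/*
 * Illustrative sketch (not part of the original header): the pattern behind
 * the RNG array above. Boost's classic Random interface couples an engine
 * with a distribution via variate_generator; this trainer keeps NUM_RNGS
 * such pairs and indexes into them pseudo-randomly because thread-local
 * state is hard to maintain in TBB. The typedef of base_generator_type and
 * the value of NUM_RNGS below are assumptions made only to keep the example
 * self-contained; the demo also leaks its allocations, mirroring the
 * raw-pointer style of the members above rather than production practice.
 *
 *   #include <boost/random/mersenne_twister.hpp>
 *   #include <boost/random/uniform_real.hpp>
 *   #include <boost/random/variate_generator.hpp>
 *   #include <cstdlib>
 *   #include <iostream>
 *
 *   typedef boost::mt19937 base_generator_type; // assumed engine
 *   const int NUM_RNGS = 16;                    // assumed pool size
 *
 *   int main() {
 *       boost::variate_generator<base_generator_type&,
 *               boost::uniform_real<> >* unif01[NUM_RNGS];
 *       for (int i = 0; i < NUM_RNGS; ++i) {
 *           // Seed each engine differently so the streams are distinct.
 *           base_generator_type* gen = new base_generator_type(42u + i);
 *           boost::uniform_real<>* dist = new boost::uniform_real<>(0.0, 1.0);
 *           unif01[i] = new boost::variate_generator<base_generator_type&,
 *                   boost::uniform_real<> >(*gen, *dist);
 *       }
 *       // A worker picks a slot and hopes other threads picked other slots.
 *       long rng_ind = rand() % NUM_RNGS;
 *       std::cout << (*unif01[rng_ind])() << std::endl; // uniform in [0,1)
 *       return 0;
 *   }
 */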