User Manual, Developers Guide and API Documentation

BaumWelch.cpp

Go to the documentation of this file.
00001 /*******************************************************************************
00002  * This file is part of openWNS (open Wireless Network Simulator)
00003  * _____________________________________________________________________________
00004  *
00005  * Copyright (C) 2004-2007
00006  * Chair of Communication Networks (ComNets)
00007  * Kopernikusstr. 16, D-52074 Aachen, Germany
00008  * phone: ++49-241-80-27910,
00009  * fax: ++49-241-80-22242
00010  * email: info@openwns.org
00011  * www: http://www.openwns.org
00012  * _____________________________________________________________________________
00013  *
00014  * openWNS is free software; you can redistribute it and/or modify it under the
00015  * terms of the GNU Lesser General Public License version 2 as published by the
00016  * Free Software Foundation;
00017  *
00018  * openWNS is distributed in the hope that it will be useful, but WITHOUT ANY
00019  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
00020  * A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
00021  * details.
00022  *
00023  * You should have received a copy of the GNU Lesser General Public License
00024  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
00025  *
00026  ******************************************************************************/
00027 #include <CONSTANZE/BaumWelch.hpp>
00028 
00029 using namespace constanze;
00030 
00031 BaumWelch::BaumWelch()
00032 {
00033 }
00034 
00035 BaumWelch::~BaumWelch()
00036 {
00037 }
00038 
00039 HMM*
00040 BaumWelch::baumWelch(HMM *initialHMM, std::vector<int> *observationVector, int iterations)
00041 {
00042     int numberOfStates = initialHMM->getNumberOfStates();
00043 
00044     for(int iteration = 0; iteration < iterations; iteration++) {
00045 
00046         // execution in every iteration forward and backward procedure
00047         std::vector<std::vector<baumWelchDataType>*> *forwardMatrix = forward(initialHMM, observationVector);
00048         std::vector<std::vector<baumWelchDataType>*> *backwardMatrix = backward(initialHMM, observationVector);
00049 
00050         // calculate new start probability
00051         std::vector<baumWelchDataType> *startStateProbability = new std::vector<baumWelchDataType>();
00052 
00053         for(int i = 0; i < numberOfStates; i++){
00054             startStateProbability->push_back(calculateGamma(i,0,forwardMatrix,backwardMatrix,initialHMM));
00055         }
00056 
00057         // calculate new transition matrix
00058         std::vector<std::vector<baumWelchDataType>*> *transitionsMatrix = new std::vector<std::vector<baumWelchDataType>*>();
00059         for(int i = 0; i < numberOfStates; i++){
00060             transitionsMatrix->push_back(new std::vector<baumWelchDataType>());
00061         }
00062 
00063         for(int i = 0; i < numberOfStates; i++){
00064             for(int  j = 0; j < numberOfStates; j++){
00065                 baumWelchDataType num = 0.0;
00066                 baumWelchDataType denom = 0.0;
00067 
00068                 for(unsigned int t = 0; t < observationVector->size()-1; t++){
00069                     num += calculateXi(t,i,j,observationVector,forwardMatrix,backwardMatrix,initialHMM);
00070                     denom += calculateGamma(i,t,forwardMatrix,backwardMatrix,initialHMM);
00071                 }
00072                 transitionsMatrix->at(i)->push_back(divide(num,denom));
00073             }
00074         }
00075 
00080         std::vector<std::vector<baumWelchDataType>*> *observationMatrix = new std::vector<std::vector<baumWelchDataType>*>();
00081 
00082         for(int i = 0; i < numberOfStates; i++){
00083             observationMatrix->push_back(new std::vector<baumWelchDataType>());
00084         }
00085 
00086         for(int i = 0; i < numberOfStates; i++){
00087             for(int j = 0; j < numberOfStates; j++){
00088                 if(i == j)
00089                     observationMatrix->at(i)->push_back(1.0);
00090                 else
00091                     observationMatrix->at(i)->push_back(0.0);
00092             }
00093         }
00094 
00095         // set memory for hmm free
00096         delete initialHMM;
00097 
00098         initialHMM = new HMM(numberOfStates,transitionsMatrix,observationMatrix,startStateProbability);
00099 
00100         for(int i = 0; i < numberOfStates; i++){
00101             delete forwardMatrix->at(i);
00102             delete backwardMatrix->at(i);
00103         }
00104 
00105         delete forwardMatrix;
00106         delete backwardMatrix;
00107 
00108     } // for iteration
00109 
00114     return initialHMM;
00115 
00116 }
00117 
00118 std::vector<std::vector<baumWelchDataType>*>*
00119 BaumWelch::forward(HMM *initialHMM, std::vector<int> *observationVector)
00120 {
00125     std::vector<std::vector<baumWelchDataType>*> *forwardMatrix = new std::vector<std::vector<baumWelchDataType>*>();
00126 
00127     for(int k = 0; k < initialHMM->getNumberOfStates(); k++){
00128         forwardMatrix->push_back(new std::vector<baumWelchDataType>());
00129     }
00130 
00139     for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00140         forwardMatrix->at(i)->push_back(initialHMM->getStartStateProbability(i) * initialHMM->getElementInObservationMatrix(i,observationVector->at(0)));
00141     }
00142 
00148     for(unsigned int t = 1; t < observationVector->size(); t++){
00149         for(int j = 0; j < initialHMM->getNumberOfStates(); j++){
00150             baumWelchDataType sum = 0.0;
00151             for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00152                 sum += forwardMatrix->at(i)->at(t-1) * initialHMM->getElementInTransitionsMatrix(i,j);
00153             }
00154             forwardMatrix->at(j)->push_back(sum * initialHMM->getElementInObservationMatrix(j,observationVector->at(t)));
00155         }
00156     }
00157 
00158     return forwardMatrix;
00159 }
00160 
00161 std::vector<std::vector<baumWelchDataType>*>*
00162 BaumWelch::backward(HMM *initialHMM, std::vector<int> *observationVector)
00163 {
00164 
00169     std::vector<std::vector<baumWelchDataType>*> *backwardMatrix = new std::vector<std::vector<baumWelchDataType>*>();
00170 
00175     std::vector<std::vector<baumWelchDataType>*> *helpMatrix = new std::vector<std::vector<baumWelchDataType>*>();
00176 
00177     for(int k = 0; k < initialHMM->getNumberOfStates(); k++){
00178         backwardMatrix->push_back(new std::vector<baumWelchDataType>());
00179         helpMatrix->push_back(new std::vector<baumWelchDataType>());
00180     }
00181 
00189     for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00190         helpMatrix->at(i)->push_back(1.0);
00191     }
00192 
00197     for(int t = observationVector->size()-2; t >= 0; t--){
00198         for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00199             baumWelchDataType sum = 0;
00200             for(int j = 0; j < initialHMM->getNumberOfStates(); j++){
00201                 sum += helpMatrix->at(j)->at(observationVector->size()-2-t) * initialHMM->getElementInTransitionsMatrix(i,j) * initialHMM->getElementInObservationMatrix(j,observationVector->at(t+1));
00202             }
00203             helpMatrix->at(i)->push_back(sum);
00204         }
00205     }
00206 
00207     // reverse helpMatrix to achieve backward matrix
00208     for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00209         for(unsigned int j = 0; j < observationVector->size(); j++){
00210             backwardMatrix->at(i)->push_back(helpMatrix->at(i)->at(observationVector->size()-1-j));
00211         }
00212     }
00213 
00214     // set memory free
00215     for(int i = 0; i < initialHMM->getNumberOfStates(); i++){
00216         delete helpMatrix->at(i);
00217     }
00218 
00219     delete helpMatrix;
00220 
00221     return backwardMatrix;
00222 }
00223 
00228 baumWelchDataType
00229 BaumWelch::calculateGamma(int i,int t,std::vector<std::vector<baumWelchDataType>*> *forwardMatrix,std::vector<std::vector<baumWelchDataType>*> *backwardMatrix,HMM *h)
00230 {
00231 
00232     // num = alpha_t(i) * beta_t(i)
00233     baumWelchDataType num = forwardMatrix->at(i)->at(t) * backwardMatrix->at(i)->at(t);
00234     baumWelchDataType denom = 0.0;
00235 
00236     // denom = sum_{alpha_t(j) * beta_t(j)}
00237     for(int j = 0; j < h->getNumberOfStates(); j++)
00238         denom += forwardMatrix->at(j)->at(t) * backwardMatrix->at(j)->at(t);
00239     return divide(num,denom);
00240 
00241 }
00242 
00247 baumWelchDataType
00248 BaumWelch::calculateXi(int t,int i,int j,std::vector<int> *observationVector,std::vector<std::vector<baumWelchDataType>*> *forwardMatrix,std::vector<std::vector<baumWelchDataType>*> *backwardMatrix,HMM *h)
00249 {
00250 
00251     baumWelchDataType num;
00252     baumWelchDataType denom = 0.0;
00253 
00254     // num = alpha_t(i) * a_ij * b_j(O_t+1) * beta_t+1(j)
00255     num = forwardMatrix->at(i)->at(t) * h->getElementInTransitionsMatrix(i,j) * h->getElementInObservationMatrix(j,observationVector->at(t+1)) * backwardMatrix->at(j)->at(t+1);
00256 
00257     // denom = sum_{sum_{num}}
00258     for(int j = 0; j < h->getNumberOfStates(); j++){
00259         for(int i = 0; i < h->getNumberOfStates(); i++){
00260             denom += forwardMatrix->at(i)->at(t) * h->getElementInTransitionsMatrix(i,j) * h->getElementInObservationMatrix(j,observationVector->at(t+1)) * backwardMatrix->at(j)->at(t+1);
00261         }
00262     }
00263     // try { result=num/denom; } catch ...
00264     return divide(num,denom);
00265 }
00266 
00267 baumWelchDataType
00268 BaumWelch::divide(baumWelchDataType num, baumWelchDataType denom)
00269 {
00270     if(denom == 0.0)
00271         return 0.0;
00272     else
00273         return num/denom;
00274 }
00275 
00276 HMM*
00277 BaumWelch::getHMM(){
00278     return hmm;
00279 }
00280 

Generated on Mon May 21 03:32:20 2012 for openWNS by  doxygen 1.5.5