![]() |
User Manual, Developers Guide and API Documentation |
![]() |
00001 /******************************************************************************* 00002 * This file is part of openWNS (open Wireless Network Simulator) 00003 * _____________________________________________________________________________ 00004 * 00005 * Copyright (C) 2004-2007 00006 * Chair of Communication Networks (ComNets) 00007 * Kopernikusstr. 16, D-52074 Aachen, Germany 00008 * phone: ++49-241-80-27910, 00009 * fax: ++49-241-80-22242 00010 * email: info@openwns.org 00011 * www: http://www.openwns.org 00012 * _____________________________________________________________________________ 00013 * 00014 * openWNS is free software; you can redistribute it and/or modify it under the 00015 * terms of the GNU Lesser General Public License version 2 as published by the 00016 * Free Software Foundation; 00017 * 00018 * openWNS is distributed in the hope that it will be useful, but WITHOUT ANY 00019 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 00020 * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 00021 * details. 00022 * 00023 * You should have received a copy of the GNU Lesser General Public License 00024 * along with this program. If not, see <http://www.gnu.org/licenses/>. 00025 * 00026 ******************************************************************************/ 00027 #include <CONSTANZE/BaumWelch.hpp> 00028 00029 using namespace constanze; 00030 00031 BaumWelch::BaumWelch() 00032 { 00033 } 00034 00035 BaumWelch::~BaumWelch() 00036 { 00037 } 00038 00039 HMM* 00040 BaumWelch::baumWelch(HMM *initialHMM, std::vector<int> *observationVector, int iterations) 00041 { 00042 int numberOfStates = initialHMM->getNumberOfStates(); 00043 00044 for(int iteration = 0; iteration < iterations; iteration++) { 00045 00046 // execution in every iteration forward and backward procedure 00047 std::vector<std::vector<baumWelchDataType>*> *forwardMatrix = forward(initialHMM, observationVector); 00048 std::vector<std::vector<baumWelchDataType>*> *backwardMatrix = backward(initialHMM, observationVector); 00049 00050 // calculate new start probability 00051 std::vector<baumWelchDataType> *startStateProbability = new std::vector<baumWelchDataType>(); 00052 00053 for(int i = 0; i < numberOfStates; i++){ 00054 startStateProbability->push_back(calculateGamma(i,0,forwardMatrix,backwardMatrix,initialHMM)); 00055 } 00056 00057 // calculate new transition matrix 00058 std::vector<std::vector<baumWelchDataType>*> *transitionsMatrix = new std::vector<std::vector<baumWelchDataType>*>(); 00059 for(int i = 0; i < numberOfStates; i++){ 00060 transitionsMatrix->push_back(new std::vector<baumWelchDataType>()); 00061 } 00062 00063 for(int i = 0; i < numberOfStates; i++){ 00064 for(int j = 0; j < numberOfStates; j++){ 00065 baumWelchDataType num = 0.0; 00066 baumWelchDataType denom = 0.0; 00067 00068 for(unsigned int t = 0; t < observationVector->size()-1; t++){ 00069 num += calculateXi(t,i,j,observationVector,forwardMatrix,backwardMatrix,initialHMM); 00070 denom += calculateGamma(i,t,forwardMatrix,backwardMatrix,initialHMM); 00071 } 00072 transitionsMatrix->at(i)->push_back(divide(num,denom)); 00073 } 00074 } 00075 00080 std::vector<std::vector<baumWelchDataType>*> *observationMatrix = new std::vector<std::vector<baumWelchDataType>*>(); 00081 00082 for(int i = 0; i < numberOfStates; i++){ 00083 observationMatrix->push_back(new std::vector<baumWelchDataType>()); 00084 } 00085 00086 for(int i = 0; i < numberOfStates; i++){ 00087 for(int j = 0; j < numberOfStates; j++){ 00088 if(i == j) 00089 observationMatrix->at(i)->push_back(1.0); 00090 else 00091 observationMatrix->at(i)->push_back(0.0); 00092 } 00093 } 00094 00095 // set memory for hmm free 00096 delete initialHMM; 00097 00098 initialHMM = new HMM(numberOfStates,transitionsMatrix,observationMatrix,startStateProbability); 00099 00100 for(int i = 0; i < numberOfStates; i++){ 00101 delete forwardMatrix->at(i); 00102 delete backwardMatrix->at(i); 00103 } 00104 00105 delete forwardMatrix; 00106 delete backwardMatrix; 00107 00108 } // for iteration 00109 00114 return initialHMM; 00115 00116 } 00117 00118 std::vector<std::vector<baumWelchDataType>*>* 00119 BaumWelch::forward(HMM *initialHMM, std::vector<int> *observationVector) 00120 { 00125 std::vector<std::vector<baumWelchDataType>*> *forwardMatrix = new std::vector<std::vector<baumWelchDataType>*>(); 00126 00127 for(int k = 0; k < initialHMM->getNumberOfStates(); k++){ 00128 forwardMatrix->push_back(new std::vector<baumWelchDataType>()); 00129 } 00130 00139 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00140 forwardMatrix->at(i)->push_back(initialHMM->getStartStateProbability(i) * initialHMM->getElementInObservationMatrix(i,observationVector->at(0))); 00141 } 00142 00148 for(unsigned int t = 1; t < observationVector->size(); t++){ 00149 for(int j = 0; j < initialHMM->getNumberOfStates(); j++){ 00150 baumWelchDataType sum = 0.0; 00151 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00152 sum += forwardMatrix->at(i)->at(t-1) * initialHMM->getElementInTransitionsMatrix(i,j); 00153 } 00154 forwardMatrix->at(j)->push_back(sum * initialHMM->getElementInObservationMatrix(j,observationVector->at(t))); 00155 } 00156 } 00157 00158 return forwardMatrix; 00159 } 00160 00161 std::vector<std::vector<baumWelchDataType>*>* 00162 BaumWelch::backward(HMM *initialHMM, std::vector<int> *observationVector) 00163 { 00164 00169 std::vector<std::vector<baumWelchDataType>*> *backwardMatrix = new std::vector<std::vector<baumWelchDataType>*>(); 00170 00175 std::vector<std::vector<baumWelchDataType>*> *helpMatrix = new std::vector<std::vector<baumWelchDataType>*>(); 00176 00177 for(int k = 0; k < initialHMM->getNumberOfStates(); k++){ 00178 backwardMatrix->push_back(new std::vector<baumWelchDataType>()); 00179 helpMatrix->push_back(new std::vector<baumWelchDataType>()); 00180 } 00181 00189 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00190 helpMatrix->at(i)->push_back(1.0); 00191 } 00192 00197 for(int t = observationVector->size()-2; t >= 0; t--){ 00198 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00199 baumWelchDataType sum = 0; 00200 for(int j = 0; j < initialHMM->getNumberOfStates(); j++){ 00201 sum += helpMatrix->at(j)->at(observationVector->size()-2-t) * initialHMM->getElementInTransitionsMatrix(i,j) * initialHMM->getElementInObservationMatrix(j,observationVector->at(t+1)); 00202 } 00203 helpMatrix->at(i)->push_back(sum); 00204 } 00205 } 00206 00207 // reverse helpMatrix to achieve backward matrix 00208 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00209 for(unsigned int j = 0; j < observationVector->size(); j++){ 00210 backwardMatrix->at(i)->push_back(helpMatrix->at(i)->at(observationVector->size()-1-j)); 00211 } 00212 } 00213 00214 // set memory free 00215 for(int i = 0; i < initialHMM->getNumberOfStates(); i++){ 00216 delete helpMatrix->at(i); 00217 } 00218 00219 delete helpMatrix; 00220 00221 return backwardMatrix; 00222 } 00223 00228 baumWelchDataType 00229 BaumWelch::calculateGamma(int i,int t,std::vector<std::vector<baumWelchDataType>*> *forwardMatrix,std::vector<std::vector<baumWelchDataType>*> *backwardMatrix,HMM *h) 00230 { 00231 00232 // num = alpha_t(i) * beta_t(i) 00233 baumWelchDataType num = forwardMatrix->at(i)->at(t) * backwardMatrix->at(i)->at(t); 00234 baumWelchDataType denom = 0.0; 00235 00236 // denom = sum_{alpha_t(j) * beta_t(j)} 00237 for(int j = 0; j < h->getNumberOfStates(); j++) 00238 denom += forwardMatrix->at(j)->at(t) * backwardMatrix->at(j)->at(t); 00239 return divide(num,denom); 00240 00241 } 00242 00247 baumWelchDataType 00248 BaumWelch::calculateXi(int t,int i,int j,std::vector<int> *observationVector,std::vector<std::vector<baumWelchDataType>*> *forwardMatrix,std::vector<std::vector<baumWelchDataType>*> *backwardMatrix,HMM *h) 00249 { 00250 00251 baumWelchDataType num; 00252 baumWelchDataType denom = 0.0; 00253 00254 // num = alpha_t(i) * a_ij * b_j(O_t+1) * beta_t+1(j) 00255 num = forwardMatrix->at(i)->at(t) * h->getElementInTransitionsMatrix(i,j) * h->getElementInObservationMatrix(j,observationVector->at(t+1)) * backwardMatrix->at(j)->at(t+1); 00256 00257 // denom = sum_{sum_{num}} 00258 for(int j = 0; j < h->getNumberOfStates(); j++){ 00259 for(int i = 0; i < h->getNumberOfStates(); i++){ 00260 denom += forwardMatrix->at(i)->at(t) * h->getElementInTransitionsMatrix(i,j) * h->getElementInObservationMatrix(j,observationVector->at(t+1)) * backwardMatrix->at(j)->at(t+1); 00261 } 00262 } 00263 // try { result=num/denom; } catch ... 00264 return divide(num,denom); 00265 } 00266 00267 baumWelchDataType 00268 BaumWelch::divide(baumWelchDataType num, baumWelchDataType denom) 00269 { 00270 if(denom == 0.0) 00271 return 0.0; 00272 else 00273 return num/denom; 00274 } 00275 00276 HMM* 00277 BaumWelch::getHMM(){ 00278 return hmm; 00279 } 00280
1.5.5