00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #ifndef CParticleFilterData_H
00029 #define CParticleFilterData_H
00030
00031 #include <mrpt/utils/utils_defs.h>
00032 #include <mrpt/bayes/CProbabilityParticle.h>
00033
00034 #include <algorithm>
00035
00036 namespace mrpt
00037 {
00038 namespace bayes
00039 {
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 template <class T>
00056 class CParticleFilterData
00057 {
00058 public:
00059 typedef T CParticleDataContent;
00060 typedef CProbabilityParticle<T> CParticleData;
00061 typedef std::deque<CParticleData> CParticleList;
00062
00063 CParticleList m_particles;
00064
00065
00066 CParticleFilterData() : m_particles(0)
00067 { }
00068
00069
00070
00071 void clearParticles()
00072 {
00073 MRPT_START
00074 for (typename CParticleList::iterator it=m_particles.begin();it!=m_particles.end();it++)
00075 if (it->d) delete it->d;
00076 m_particles.clear();
00077 MRPT_END
00078 }
00079
00080
00081
00082 virtual ~CParticleFilterData()
00083 {
00084 MRPT_START
00085 clearParticles();
00086 MRPT_END
00087 }
00088
00089
00090
00091
00092 void writeParticlesToStream( utils::CStream &out ) const
00093 {
00094 MRPT_START
00095 uint32_t n = static_cast<uint32_t>(m_particles.size());
00096 out << n;
00097 typename CParticleList::const_iterator it;
00098 for (it=m_particles.begin();it!=m_particles.end();it++)
00099 out << it->log_w << (*it->d);
00100 MRPT_END
00101 }
00102
00103
00104
00105
00106 void readParticlesFromStream(utils::CStream &in)
00107 {
00108 MRPT_START
00109 clearParticles();
00110 uint32_t n;
00111 in >> n;
00112 m_particles.resize(n);
00113 typename CParticleList::iterator it;
00114 for (it=m_particles.begin();it!=m_particles.end();it++)
00115 {
00116 in >> it->log_w;
00117 it->d = new T();
00118 in >> *it->d;
00119 }
00120 MRPT_END
00121 }
00122
00123
00124
00125
00126 void getWeights( vector_double &out_logWeights ) const
00127 {
00128 MRPT_START
00129 out_logWeights.resize(m_particles.size());
00130 vector_double::iterator it;
00131 typename CParticleList::const_iterator it2;
00132 for (it=out_logWeights.begin(),it2=m_particles.begin();it2!=m_particles.end();it++,it2++)
00133 *it = it2->log_w;
00134 MRPT_END
00135 }
00136
00137
00138
00139 const CParticleData * getMostLikelyParticle() const
00140 {
00141 MRPT_START
00142 const CParticleData *ret = NULL;
00143 ASSERT_(m_particles.size()>0)
00144
00145 typename CParticleList::const_iterator it;
00146 for (it=m_particles.begin();it!=m_particles.end();it++)
00147 {
00148 if (ret==NULL || it->log_w > ret->log_w)
00149 ret = &(*it);
00150 }
00151 return ret;
00152 MRPT_END
00153 }
00154
00155
00156 };
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
/** Expands, inside a class declaration, to the public virtual methods getW(), setW(),
  *  particlesCount(), normalizeWeights(), ESS() and performSubstitution(), all of them
  *  operating on a member named "m_particles".
  *
  *  Requirements at the point of expansion:
  *   - The member "m_particles" and the type name "CParticleList" must be visible
  *     (for instance, by deriving from CParticleFilterData<T>).
  *   - The macros MRPT_START, MRPT_END, ASSERT_ and THROW_EXCEPTION_CUSTOM_MSG1 must be defined.
  *
  * \param T The particle data type; performSubstitution() duplicates particles with "new T(const T&)".
  *
  * \note Only C-style comments may be used inside the macro body: a "//" comment followed by
  *  a line continuation would swallow the next line.
  */
#define IMPLEMENT_PARTICLE_FILTER_CAPABLE(T) \
    public: \
    /* Returns the log-weight of the i'th particle. An out-of-range index is reported */ \
    /* through THROW_EXCEPTION_CUSTOM_MSG1.                                           */ \
    virtual double getW(size_t i) const \
    { \
        MRPT_START; \
        if (i>=m_particles.size()) THROW_EXCEPTION_CUSTOM_MSG1("Index %i is out of range!",static_cast<int>(i)); \
        return m_particles[i].log_w; \
        MRPT_END; \
    } \
    /* Sets the log-weight of the i'th particle. An out-of-range index is reported */ \
    /* through THROW_EXCEPTION_CUSTOM_MSG1.                                        */ \
    virtual void setW(size_t i, double w) \
    { \
        MRPT_START; \
        if (i>=m_particles.size()) THROW_EXCEPTION_CUSTOM_MSG1("Index %i is out of range!",static_cast<int>(i)); \
        m_particles[i].log_w = w; \
        MRPT_END; \
    } \
    /* Returns the number of particles currently stored. */ \
    virtual size_t particlesCount() const { return m_particles.size(); } \
    /* Subtracts the largest log-weight from every particle, so the maximum becomes 0. */ \
    /* If out_max_log_w is not NULL it receives the maximum found before the shift.    */ \
    /* Returns exp(max-min) of the original log-weights. With no particles it returns  */ \
    /* 0 and leaves *out_max_log_w untouched.                                          */ \
    virtual double normalizeWeights( double *out_max_log_w = NULL ) \
    { \
        MRPT_START; \
        if (m_particles.empty()) return 0; \
        \
        CParticleList::iterator it; \
        double minW = m_particles[0].log_w; \
        double maxW = minW; \
        for (it=m_particles.begin();it!=m_particles.end();++it) \
        { \
            maxW = std::max<double>( maxW, it->log_w ); \
            minW = std::min<double>( minW, it->log_w ); \
        } \
        \
        for (it=m_particles.begin();it!=m_particles.end();++it) \
            it->log_w -= maxW; \
        \
        if (out_max_log_w) \
            *out_max_log_w = maxW; \
        \
        return exp(maxW-minW); \
        MRPT_END; \
    } \
    /* Returns the normalized effective sample size 1/(N*sum_i(w_i^2)), where w_i are  */ \
    /* the linear weights normalized to sum 1 and N is the number of particles.        */ \
    /* Returns 0 for an empty set or when the weights are degenerate (NaN/infinite).   */ \
    virtual double ESS() \
    { \
        MRPT_START; \
        if (m_particles.empty()) return 0; \
        \
        CParticleList::iterator it; \
        /* Shift by the largest log-weight before exponentiating: the normalized weights */ \
        /* are unchanged, but exp() can no longer underflow for every particle at once   */ \
        /* (which would yield 0/0 = NaN).                                                */ \
        double maxLogW = m_particles.begin()->log_w; \
        for (it=m_particles.begin();it!=m_particles.end();++it) \
            maxLogW = std::max<double>( maxLogW, it->log_w ); \
        \
        double sumW = 0; \
        double sumSqW = 0; \
        for (it=m_particles.begin();it!=m_particles.end();++it) \
        { \
            const double linW = exp( it->log_w - maxLogW ); \
            sumW   += linW; \
            sumSqW += linW*linW; \
        } \
        \
        /* Written as a negated ">" so that NaN sums also take this branch. */ \
        if (!(sumSqW>0)) \
            return 0; \
        /* sum((w/S)^2) == Q/S^2, hence 1/(N*cum) == S^2/(N*Q). */ \
        return (sumW*sumW)/(m_particles.size()*sumSqW); \
        MRPT_END; \
    } \
    \
    /* Replaces the particle set by the particles selected in "indx" (indices into the */ \
    /* current set; repetitions allowed). The new set is ordered by ascending index.   */ \
    /* The first occurrence of an index takes over the existing data object, further   */ \
    /* occurrences get a copy made with "new T(...)", and data of particles that are   */ \
    /* not selected is deleted. Log-weights are copied along with the particles.       */ \
    /* All indices are validated before anything is modified; an out-of-range index is */ \
    /* reported through THROW_EXCEPTION_CUSTOM_MSG1.                                   */ \
    virtual void performSubstitution( const std::vector<size_t> &indx) \
    { \
        MRPT_START; \
        const size_t M_old = m_particles.size(); \
        \
        std::vector<size_t> sorted_indx(indx); \
        std::sort( sorted_indx.begin(), sorted_indx.end() ); \
        \
        /* After sorting, the last index is the largest: checking it covers all of them. */ \
        if (!sorted_indx.empty() && sorted_indx.back()>=M_old) \
            THROW_EXCEPTION_CUSTOM_MSG1("Index %i is out of range!",static_cast<int>(sorted_indx.back())); \
        \
        std::vector<bool> oldParticlesReused(M_old,false); \
        CParticleList parts; \
        parts.resize( sorted_indx.size() ); \
        \
        size_t lastIndxOld = 0; \
        CParticleList::iterator itDest = parts.begin(); \
        for (size_t i=0;i<sorted_indx.size();++i,++itDest) \
        { \
            const size_t sorted_idx = sorted_indx[i]; \
            itDest->log_w = m_particles[ sorted_idx ].log_w; \
            \
            /* Indices are sorted, so old particles below the current index that were */ \
            /* not selected will never be needed again: release them right away.     */ \
            for (size_t j=lastIndxOld;j<sorted_idx;++j) \
            { \
                if (!oldParticlesReused[j]) \
                { \
                    delete m_particles[j].d; \
                    m_particles[j].d = NULL; \
                } \
            } \
            lastIndxOld = sorted_idx; \
            \
            if (!oldParticlesReused[sorted_idx]) \
            { \
                /* First use of this old particle: just take over its data pointer. */ \
                itDest->d = m_particles[ sorted_idx ].d; \
                oldParticlesReused[sorted_idx]=true; \
            } \
            else \
            { \
                /* Already taken: this occurrence needs its own copy of the data. */ \
                ASSERT_( m_particles[ sorted_idx ].d != NULL); \
                itDest->d = new T( *m_particles[ sorted_idx ].d ); \
            } \
        } \
        \
        /* Release whatever remains unselected (deleting NULL again is harmless). */ \
        { \
            std::vector<bool>::const_iterator reusedIt = oldParticlesReused.begin(); \
            for (CParticleList::iterator itOld=m_particles.begin();itOld!=m_particles.end();++itOld,++reusedIt) \
            { \
                if (! *reusedIt ) \
                { \
                    delete itOld->d; \
                    itOld->d = NULL; \
                } \
            } \
        } \
        \
        /* Hand the new set over to m_particles, field by field so that only the */ \
        /* pointers (not the pointed-to data) are transferred.                   */ \
        m_particles.resize( parts.size() ); \
        { \
            CParticleList::iterator itNew = m_particles.begin(); \
            for (CParticleList::iterator itSrc=parts.begin();itSrc!=parts.end();++itSrc,++itNew) \
            { \
                itNew->log_w = itSrc->log_w; \
                itNew->d = itSrc->d; \
                itSrc->d = NULL; \
            } \
        } \
        parts.clear(); \
        MRPT_END; \
    }

00291 }
00292 }
00293 #endif