mlpack  2.0.1
em_fit.hpp
Go to the documentation of this file.
1 
16 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
17 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
18 
#include <mlpack/core.hpp>

// Default clustering mechanism.
#include <mlpack/methods/kmeans/kmeans.hpp>
// Default covariance matrix constraint.
#include "positive_definite_constraint.hpp"
25 
26 namespace mlpack {
27 namespace gmm {
28 
42 template<typename InitialClusteringType = kmeans::KMeans<>,
43  typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
44 class EMFit
45 {
46  public:
64  EMFit(const size_t maxIterations = 300,
65  const double tolerance = 1e-10,
66  InitialClusteringType clusterer = InitialClusteringType(),
67  CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
68 
84  void Estimate(const arma::mat& observations,
85  std::vector<distribution::GaussianDistribution>& dists,
86  arma::vec& weights,
87  const bool useInitialModel = false);
88 
106  void Estimate(const arma::mat& observations,
107  const arma::vec& probabilities,
108  std::vector<distribution::GaussianDistribution>& dists,
109  arma::vec& weights,
110  const bool useInitialModel = false);
111 
113  const InitialClusteringType& Clusterer() const { return clusterer; }
115  InitialClusteringType& Clusterer() { return clusterer; }
116 
118  const CovarianceConstraintPolicy& Constraint() const { return constraint; }
120  CovarianceConstraintPolicy& Constraint() { return constraint; }
121 
123  size_t MaxIterations() const { return maxIterations; }
125  size_t& MaxIterations() { return maxIterations; }
126 
128  double Tolerance() const { return tolerance; }
130  double& Tolerance() { return tolerance; }
131 
133  template<typename Archive>
134  void Serialize(Archive& ar, const unsigned int version);
135 
136  private:
147  void InitialClustering(const arma::mat& observations,
148  std::vector<distribution::GaussianDistribution>& dists,
149  arma::vec& weights);
150 
161  double LogLikelihood(const arma::mat& data,
162  const std::vector<distribution::GaussianDistribution>&
163  dists,
164  const arma::vec& weights) const;
165 
169  double tolerance;
171  InitialClusteringType clusterer;
173  CovarianceConstraintPolicy constraint;
174 };
175 
176 } // namespace gmm
177 } // namespace mlpack
178 
179 // Include implementation.
180 #include "em_fit_impl.hpp"
181 
182 #endif
This class contains methods which can fit a GMM to observations using the EM algorithm.
Definition: em_fit.hpp:44
void Estimate(const arma::mat &observations, std::vector< distribution::GaussianDistribution > &dists, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
double LogLikelihood(const arma::mat &data, const std::vector< distribution::GaussianDistribution > &dists, const arma::vec &weights) const
Calculate the log-likelihood of a model.
double & Tolerance()
Modify the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:130
const CovarianceConstraintPolicy & Constraint() const
Get the covariance constraint policy class.
Definition: em_fit.hpp:118
Linear algebra utility functions, generally performed on matrices or vectors.
void InitialClustering(const arma::mat &observations, std::vector< distribution::GaussianDistribution > &dists, arma::vec &weights)
Run the clusterer, and then turn the cluster assignments into Gaussians.
size_t maxIterations
Maximum iterations of EM algorithm.
Definition: em_fit.hpp:167
CovarianceConstraintPolicy constraint
Object which applies constraints to the covariance matrix.
Definition: em_fit.hpp:173
size_t & MaxIterations()
Modify the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:125
size_t MaxIterations() const
Get the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:123
InitialClusteringType & Clusterer()
Modify the clusterer.
Definition: em_fit.hpp:115
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
Construct the EMFit object, optionally passing an InitialClusteringType object (just in case it needs to store state).
CovarianceConstraintPolicy & Constraint()
Modify the covariance constraint policy class.
Definition: em_fit.hpp:120
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:128
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
const InitialClusteringType & Clusterer() const
Get the clusterer.
Definition: em_fit.hpp:113
InitialClusteringType clusterer
Object which will perform the clustering.
Definition: em_fit.hpp:171
double tolerance
Tolerance for convergence of EM.
Definition: em_fit.hpp:169
void Serialize(Archive &ar, const unsigned int version)
Serialize the fitter.