// NOTE(review): this listing is fragmentary -- the leading numbers are the
// original file's line numbers, and the lines between them are missing.
// The `min` below (original line 52) is a default argument of the
// SVDBatchLearning constructor; the full signature listed later in this
// fragment is
// SVDBatchLearning(u=0.0002, kw=0, kh=0, momentum=0.9, min=-DBL_MIN, max=DBL_MAX).
// NOTE(review): DBL_MIN is the smallest positive normalized double
// (about 2.2e-308), so -DBL_MIN is a tiny negative number right next to zero,
// NOT the most negative double. If `min` is meant as the counterpart of
// max = DBL_MAX (e.g. a lower clamp bound), -DBL_MAX was probably intended.
// How `min` is used is not visible here -- confirm before changing it.
20 #ifndef __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP 21 #define __MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCHLEARNING_HPP 52 double min = -DBL_MIN,
// Initialize(dataset, rank) -- fragment: only the start of the body is
// visible. It records the dataset's dimensions as n (number of rows) and
// m (number of columns). How `rank` is used, and what (if anything) is
// allocated here, is on lines not included in this listing.
57 template<
typename MatType>
58 void Initialize(
const MatType& dataset,
const size_t rank)
60 const size_t n = dataset.n_rows;
61 const size_t m = dataset.n_cols;
// WUpdate for a generic MatType V (arma::sp_mat has its own specialization
// in this file) -- fragment; the signature and several body lines are
// missing from this listing.
// The visible part builds an n x r step matrix deltaW. For every nonzero
// V(i, j), it adds the residual (V(i, j) - dot(W.row(i), H.col(j))) scaled
// by H.col(j)^T to deltaW.row(i). When kw != 0, it also subtracts
// kw * W.row(i), a weight-decay style term on W.
// What is done with deltaW afterwards (momentum, step size u) is not
// visible here.
76 template<
typename MatType>
// NOTE(review): deltaW must start at zero because rows are accumulated
// with +=. Armadillo releases before 10.5 leave mat(n, r) uninitialized.
// The zeroing may be on the omitted original line 89 -- confirm.
88 arma::mat deltaW(n, r);
// Outer loop: one row of W (index i) at a time.
91 for(
size_t i = 0;i < n;i++)
// Inner loop over all m columns of V; zero entries are filtered below.
93 for(
size_t j = 0;j < m;j++)
// Only nonzero entries contribute. Zeros are skipped, presumably because
// they are treated as unobserved values -- confirm. `val` is declared on an
// omitted line.
96 if((val = V(i, j)) != 0)
97 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
98 arma::trans(H.col(j));
// Regularization on row i. Braces are not shown, but judging by the line
// numbers this follows the j loop, i.e. it appears to run once per row.
100 if(
kw != 0) deltaW.row(i) -=
kw * W.row(i);
// HUpdate for a generic MatType V (arma::sp_mat has its own specialization
// in this file) -- fragment; the signature and several body lines are
// missing from this listing.
// The visible part builds an r x m step matrix deltaH. For every nonzero
// V(i, j), it adds the residual (V(i, j) - dot(W.row(i), H.col(j))) scaled
// by W.row(i)^T to deltaH.col(j). When kh != 0, it also subtracts
// kh * H.col(j), a weight-decay style term on H.
// What is done with deltaH afterwards (momentum, step size u) is not
// visible here.
116 template<
typename MatType>
// NOTE(review): deltaH must start at zero because columns are accumulated
// with +=. Armadillo releases before 10.5 leave mat(r, m) uninitialized.
// The zeroing may be on an omitted line -- confirm.
128 arma::mat deltaH(r, m);
// Outer loop: one column of H (index j) at a time.
131 for(
size_t j = 0;j < m;j++)
// Inner loop over all n rows of V; zero entries are filtered below.
133 for(
size_t i = 0;i < n;i++)
// Only nonzero entries contribute. Zeros are skipped, presumably because
// they are treated as unobserved values -- confirm. `val` is declared on an
// omitted line.
136 if((val = V(i, j)) != 0)
137 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) *
138 arma::trans(W.row(i));
// Regularization on column j. Braces are not shown, but judging by the line
// numbers this follows the i loop, i.e. it appears to run once per column.
140 if(
kh != 0) deltaH.col(j) -=
kh * H.col(j);
// WUpdate specialization for arma::sp_mat -- fragment; the remaining
// parameters and some body lines are missing from this listing.
// It computes the same per-entry contribution as the generic WUpdate in
// this file. Instead of scanning all n x m positions, it walks only the
// stored (nonzero) entries of V with a sparse const_iterator.
166 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
// NOTE(review): as in the generic version, deltaW must be zero before the
// += accumulation below. Check the omitted lines / Armadillo version
// (mat(n, r) is uninitialized before Armadillo 10.5).
176 arma::mat deltaW(n, r);
// Visit each stored entry of V once; (row, col) is its position.
179 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
181 size_t row = it.row();
182 size_t col = it.col();
// it.row() here is the same value as the cached `row`.
183 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
184 arma::trans(H.col(col));
// Regularization runs in a separate pass that covers every row of W.
// The sparse loop above only touches rows that have stored entries.
187 if(
kw != 0)
for(
size_t i = 0; i < n; i++)
189 deltaW.row(i) -=
kw * W.row(i);
// HUpdate specialization for arma::sp_mat -- fragment; the remaining
// parameters and some body lines are missing from this listing.
// It computes the same per-entry contribution as the generic HUpdate in
// this file. Instead of scanning all n x m positions, it walks only the
// stored (nonzero) entries of V with a sparse const_iterator.
197 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
// NOTE(review): as in the generic version, deltaH must be zero before the
// += accumulation below. Check the omitted lines / Armadillo version
// (mat(r, m) is uninitialized before Armadillo 10.5).
207 arma::mat deltaH(r, m);
// Visit each stored entry of V once; (row, col) is its position.
210 for(arma::sp_mat::const_iterator it = V.begin();it != V.end();it++)
212 size_t row = it.row();
213 size_t col = it.col();
214 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
215 arma::trans(W.row(row));
// Regularization runs in a separate pass that covers every column of H.
// The sparse loop above only touches columns that have stored entries.
218 if(
kh != 0)
for(
size_t j = 0; j < m; j++)
220 deltaH.col(j) -=
kh * H.col(j);
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(const MatType &dataset, const size_t rank)
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9, double min=-DBL_MIN, double max=DBL_MAX)
SVD Batch learning constructor.
This class implements SVD batch learning with momentum.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.