mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class Reparametrization
57 {
58  public:
61 
70  Reparametrization(const size_t latentSize,
71  const bool stochastic = true,
72  const bool includeKl = true,
73  const double beta = 1);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& input,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
100  OutputDataType const& OutputParameter() const { return outputParameter; }
102  OutputDataType& OutputParameter() { return outputParameter; }
103 
105  OutputDataType const& Delta() const { return delta; }
107  OutputDataType& Delta() { return delta; }
108 
110  size_t const& OutputSize() const { return latentSize; }
112  size_t& OutputSize() { return latentSize; }
113 
115  double Loss()
116  {
117  if (!includeKl)
118  return 0;
119 
120  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
121  - arma::pow(mean, 2) + 1) / mean.n_cols;
122  }
123 
127  template<typename Archive>
128  void serialize(Archive& ar, const unsigned int /* version */);
129 
130  private:
132  size_t latentSize;
133 
135  bool stochastic;
136 
138  bool includeKl;
139 
141  double beta;
142 
144  OutputDataType delta;
145 
147  OutputDataType gaussianSample;
148 
150  OutputDataType mean;
151 
154  OutputDataType preStdDev;
155 
157  OutputDataType stdDev;
158 
160  OutputDataType outputParameter;
161 }; // class Reparametrization
162 
163 } // namespace ann
164 } // namespace mlpack
165 
166 // Include implementation.
167 #include "reparametrization_impl.hpp"
168 
169 #endif
OutputDataType & Delta()
Modify the delta.
OutputDataType const & Delta() const
Get the delta.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
double Loss()
Get the KL divergence from the standard normal distribution.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t & OutputSize()
Modify the output size.
Reparametrization()
Create the Reparametrization object.
size_t const & OutputSize() const
Get the output size.