mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
fast_lstm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <limits>
18 
19 namespace mlpack {
20 namespace ann {
21 
62 template <
63  typename InputDataType = arma::mat,
64  typename OutputDataType = arma::mat
65 >
66 class FastLSTM
67 {
68  public:
69  // Convenience typedefs.
70  typedef typename InputDataType::elem_type InputElemType;
71  typedef typename OutputDataType::elem_type ElemType;
72 
74  FastLSTM();
75 
83  FastLSTM(const size_t inSize,
84  const size_t outSize,
85  const size_t rho = std::numeric_limits<size_t>::max());
86 
94  template<typename InputType, typename OutputType>
95  void Forward(const InputType& input, OutputType& output);
96 
106  template<typename InputType, typename ErrorType, typename GradientType>
107  void Backward(const InputType& input,
108  const ErrorType& gy,
109  GradientType& g);
110 
111  /*
112  * Reset the layer parameter.
113  */
114  void Reset();
115 
116  /*
117  * Resets the cell to accept a new input. This breaks the BPTT chain starts a
118  * new one.
119  *
120  * @param size The current maximum number of steps through time.
121  */
122  void ResetCell(const size_t size);
123 
124  /*
125  * Calculate the gradient using the output delta and the input activation.
126  *
127  * @param input The input parameter used for calculating the gradient.
128  * @param error The calculated error.
129  * @param gradient The calculated gradient.
130  */
131  template<typename InputType, typename ErrorType, typename GradientType>
132  void Gradient(const InputType& input,
133  const ErrorType& error,
134  GradientType& gradient);
135 
137  size_t Rho() const { return rho; }
139  size_t& Rho() { return rho; }
140 
142  OutputDataType const& Parameters() const { return weights; }
144  OutputDataType& Parameters() { return weights; }
145 
147  OutputDataType const& OutputParameter() const { return outputParameter; }
149  OutputDataType& OutputParameter() { return outputParameter; }
150 
152  OutputDataType const& Delta() const { return delta; }
154  OutputDataType& Delta() { return delta; }
155 
157  OutputDataType const& Gradient() const { return grad; }
159  OutputDataType& Gradient() { return grad; }
160 
164  template<typename Archive>
165  void serialize(Archive& ar, const unsigned int /* version */);
166 
167  private:
174  template<typename InputType, typename OutputType>
175  void FastSigmoid(const InputType& input, OutputType& sigmoids)
176  {
177  for (size_t i = 0; i < input.n_elem; ++i)
178  sigmoids(i) = FastSigmoid(input(i));
179  }
180 
187  ElemType FastSigmoid(const InputElemType data)
188  {
189  ElemType x = 0.5 * data;
190  ElemType z;
191  if (x >= 0)
192  {
193  if (x < 1.7)
194  z = (1.5 * x / (1 + x));
195  else if (x < 3)
196  z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7));
197  else
198  z = 0.99505475368673;
199  }
200  else
201  {
202  ElemType xx = -x;
203  if (xx < 1.7)
204  z = -(1.5 * xx / (1 + xx));
205  else if (xx < 3)
206  z = -(0.935409070603099 + 0.0458812946797165 * (xx - 1.7));
207  else
208  z = -0.99505475368673;
209  }
210 
211  return 0.5 * (z + 1.0);
212  }
213 
215  size_t inSize;
216 
218  size_t outSize;
219 
221  size_t rho;
222 
224  size_t forwardStep;
225 
227  size_t backwardStep;
228 
230  size_t gradientStep;
231 
233  OutputDataType weights;
234 
236  OutputDataType prevOutput;
237 
239  size_t batchSize;
240 
242  size_t batchStep;
243 
246  size_t gradientStepIdx;
247 
249  OutputDataType cellActivationError;
250 
252  OutputDataType delta;
253 
255  OutputDataType grad;
256 
258  OutputDataType outputParameter;
259 
261  OutputDataType output2GateWeight;
262 
264  OutputDataType input2GateWeight;
265 
267  OutputDataType input2GateBias;
268 
270  OutputDataType gate;
271 
273  OutputDataType gateActivation;
274 
276  OutputDataType stateActivation;
277 
279  OutputDataType cell;
280 
282  OutputDataType cellActivation;
283 
285  OutputDataType forgetGateError;
286 
288  OutputDataType prevError;
289 
291  OutputDataType outParameter;
292 
294  size_t rhoSize;
295 
297  size_t bpttSteps;
298 }; // class FastLSTM
299 
300 } // namespace ann
301 } // namespace mlpack
302 
303 // Include implementation.
304 #include "fast_lstm_impl.hpp"
305 
306 #endif
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:137
OutputDataType const & Delta() const
Get the delta.
Definition: fast_lstm.hpp:152
OutputDataType & Gradient()
Modify the gradient.
Definition: fast_lstm.hpp:159
OutputDataType & Delta()
Modify the delta.
Definition: fast_lstm.hpp:154
void Backward(const InputType &input, const ErrorType &gy, GradientType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
The core includes that mlpack expects: standard C++ includes and Armadillo.
OutputDataType::elem_type ElemType
Definition: fast_lstm.hpp:71
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:139
FastLSTM()
Create the Fast LSTM object.
InputDataType::elem_type InputElemType
Definition: fast_lstm.hpp:70
OutputDataType const & Gradient() const
Get the gradient.
Definition: fast_lstm.hpp:157
OutputDataType const & Parameters() const
Get the parameters.
Definition: fast_lstm.hpp:142
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: fast_lstm.hpp:149
void ResetCell(const size_t size)
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: fast_lstm.hpp:147
OutputDataType & Parameters()
Modify the parameters.
Definition: fast_lstm.hpp:144
An implementation of a faster version of the Fast LSTM network layer.
Definition: fast_lstm.hpp:66
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.