mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
brnn.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_BRNN_HPP
14 #define MLPACK_METHODS_ANN_BRNN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
20 #include "visitor/copy_visitor.hpp"
23 
29 
30 #include <ensmallen.hpp>
31 
32 namespace mlpack {
33 namespace ann {
34 
/**
 * Implementation of a standard bidirectional recurrent neural network (BRNN)
 * container. The class holds a forward RNN (forwardRNN) and a backward RNN
 * (backwardRNN), plus a merge layer (mergeLayer) and a merge-output layer
 * (mergeOutput); presumably the two directions' outputs are combined through
 * those two layers -- see brnn_impl.hpp for the exact data flow.
 *
 * @tparam OutputLayerType Output layer type (default NegativeLogLikelihood<>).
 * @tparam MergeLayerType Layer type used to merge the two directions
 *     (default Concat<>).
 * @tparam MergeOutputType Layer type applied to the merged output
 *     (default LogSoftMax<>).
 * @tparam InitializationRuleType Rule used to initialize the parameters
 *     (default RandomInitialization).
 * @tparam CustomLayers Additional user-defined layer types admitted into the
 *     LayerTypes variant.
 *
 * NOTE(review): ~BRNN() is user-declared and the constructor's default
 * arguments allocate mergeLayer/mergeOutput with `new`, but no copy
 * constructor or copy assignment operator is declared here. If ~BRNN()
 * deletes those layers (the deleteVisitor member suggests it does), an
 * implicitly generated copy would share the same pointers and delete them
 * twice. Confirm against brnn_impl.hpp before copying BRNN objects.
 */
41 template<
42  typename OutputLayerType = NegativeLogLikelihood<>,
43  typename MergeLayerType = Concat<>,
44  typename MergeOutputType = LogSoftMax<>,
45  typename InitializationRuleType = RandomInitialization,
46  typename... CustomLayers
47 >
48 class BRNN
49 {
50  public:
 //! Alias for this fully-specified BRNN instantiation.
52  using NetworkType = BRNN<OutputLayerType,
53  MergeLayerType,
54  MergeOutputType,
55  InitializationRuleType,
56  CustomLayers...>;
57 
 /**
  * Create the BRNN object.
  *
  * @param rho Maximum length of backpropagation through time; also exposed
  *     through Rho().
  * @param single Stored in the `single` member. NOTE(review): its meaning is
  *     not visible here (presumably "predict only the last element of the
  *     sequence"); confirm in brnn_impl.hpp.
  * @param outputLayer Output layer instance (passed and stored by value).
  * @param mergeLayer Layer used to merge the two directions. The default
  *     argument allocates a new object with `new`; presumably the BRNN takes
  *     ownership and releases it in ~BRNN() -- confirm.
  * @param mergeOutput Layer applied to the merged output; same allocation /
  *     ownership caveat as mergeLayer.
  * @param initializeRule Rule used to initialize the network parameters.
  */
73  BRNN(const size_t rho,
74  const bool single = false,
75  OutputLayerType outputLayer = OutputLayerType(),
76  MergeLayerType* mergeLayer = new MergeLayerType(),
77  MergeOutputType* mergeOutput = new MergeOutputType(),
78  InitializationRuleType initializeRule = InitializationRuleType());
79 
 //! Destroy the BRNN object. NOTE(review): presumably frees the merge layers
 //! via deleteVisitor; confirm in brnn_impl.hpp.
80  ~BRNN();
81 
 /**
  * Overload selected (through std::enable_if) when OptimizerType provides a
  * `size_t& MaxIterations()` member. It checks the optimizer's
  * MaxIterations() value against `samples`; presumably it emits a warning
  * when that value is less than the number of samples.
  *
  * @param optimizer Optimizer whose MaxIterations() is inspected.
  * @param samples Number of samples in the training data.
  */
91  template<typename OptimizerType>
92  typename std::enable_if<
93  HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
94  ::value, void>::type
95  WarnMessageMaxIterations(OptimizerType& optimizer, size_t samples) const;
96 
 /**
  * Overload selected when OptimizerType does NOT provide a
  * `size_t& MaxIterations()` member. NOTE(review): presumably a no-op, since
  * there is nothing to check; confirm in brnn_impl.hpp.
  *
  * @param optimizer Optimizer (not inspected for MaxIterations()).
  * @param samples Number of samples in the training data.
  */
105  template<typename OptimizerType>
106  typename std::enable_if<
107  !HasMaxIterations<OptimizerType, size_t&(OptimizerType::*)()>
108  ::value, void>::type
109  WarnMessageMaxIterations(OptimizerType& optimizer, size_t samples) const;
110 
 /**
  * Train the bidirectional recurrent neural network on the given input data
  * using the given optimizer.
  *
  * @param predictors Input training data (taken by value).
  * @param responses Training targets for the predictors (taken by value).
  * @param optimizer Instantiated optimizer used to train the model.
  * @return A double; NOTE(review): presumably the final objective value --
  *     confirm.
  */
134  template<typename OptimizerType>
135  double Train(arma::cube predictors,
136  arma::cube responses,
137  OptimizerType& optimizer);
138 
 /**
  * Train the network on the given data without an explicit optimizer
  * argument. OptimizerType defaults to ens::StandardSGD; presumably a
  * default-constructed optimizer of that type is used -- confirm in
  * brnn_impl.hpp.
  *
  * @param predictors Input training data (taken by value).
  * @param responses Training targets for the predictors (taken by value).
  * @return A double; NOTE(review): presumably the final objective value.
  */
162  template<typename OptimizerType = ens::StandardSGD>
163  double Train(arma::cube predictors, arma::cube responses);
164 
 /**
  * Predict the responses to a given set of predictors.
  *
  * @param predictors Input data (taken by value).
  * @param results Output cube that receives the predictions.
  * @param batchSize Batch size used while predicting (default 256).
  */
184  void Predict(arma::cube predictors,
185  arma::cube& results,
186  const size_t batchSize = 256);
187 
 /**
  * Evaluate the bidirectional recurrent neural network with the given
  * parameters on the batch of `batchSize` points starting at `begin`.
  *
  * @param parameters Matrix of model parameters.
  * @param begin Index of the first point in the batch.
  * @param batchSize Number of points in the batch.
  * @param deterministic Whether to evaluate in deterministic mode.
  *     NOTE(review): presumably disables training-only behavior such as
  *     dropout; confirm in brnn_impl.hpp.
  * @return Objective value for the batch (presumably; confirm).
  */
201  double Evaluate(const arma::mat& parameters,
202  const size_t begin,
203  const size_t batchSize,
204  const bool deterministic);
205 
 /**
  * Evaluate the bidirectional recurrent neural network with the given
  * parameters, without an explicit `deterministic` flag. NOTE(review): which
  * mode this overload uses is not visible here; confirm in brnn_impl.hpp.
  *
  * @param parameters Matrix of model parameters.
  * @param begin Index of the first point in the batch.
  * @param batchSize Number of points in the batch.
  * @return Objective value for the batch (presumably; confirm).
  */
218  double Evaluate(const arma::mat& parameters,
219  const size_t begin,
220  const size_t batchSize);
221 
 /**
  * Evaluate the bidirectional recurrent neural network with the given
  * parameters and also compute the gradient.
  *
  * @param parameters Matrix of model parameters.
  * @param begin Index of the first point in the batch.
  * @param gradient Output: gradient with respect to the parameters
  *     (GradType is a template parameter).
  * @param batchSize Number of points in the batch.
  * @return Objective value for the batch (presumably; confirm).
  */
235  template<typename GradType>
236  double EvaluateWithGradient(const arma::mat& parameters,
237  const size_t begin,
238  GradType& gradient,
239  const size_t batchSize);
240 
 /**
  * Evaluate the gradient of the bidirectional recurrent neural network with
  * the given parameters.
  *
  * @param parameters Matrix of model parameters.
  * @param begin Index of the first point in the batch.
  * @param gradient Output: gradient with respect to the parameters.
  * @param batchSize Number of points in the batch.
  */
254  void Gradient(const arma::mat& parameters,
255  const size_t begin,
256  arma::mat& gradient,
257  const size_t batchSize);
258 
 //! Shuffle the order of function visitation.
263  void Shuffle();
264 
265  /**
266  * Add a new module to the model.
267  *
268  * @param args Arguments for the new layer of type LayerType (presumably
     *     forwarded to the LayerType constructor; confirm in brnn_impl.hpp).
269  */
270  template <class LayerType, class... Args>
271  void Add(Args... args);
272 
273  /**
274  * Add a new module to the model.
275  *
276  * @param layer The layer to be added to the model (a LayerTypes variant).
277  */
278  void Add(LayerTypes<CustomLayers...> layer);
279 
 //! Return the number of separable functions (the number of predictor points).
281  size_t NumFunctions() const { return numFunctions; }
282 
 //! Return the initial point for the optimization.
284  const arma::mat& Parameters() const { return parameter; }
 //! Modify the initial point for the optimization.
286  arma::mat& Parameters() { return parameter; }
287 
 //! Return the maximum length of backpropagation through time.
289  const size_t& Rho() const { return rho; }
 //! Modify the maximum length of backpropagation through time.
291  size_t& Rho() { return rho; }
292 
 //! Get the matrix of responses to the input data points.
294  const arma::cube& Responses() const { return responses; }
 //! Modify the matrix of responses to the input data points.
296  arma::cube& Responses() { return responses; }
297 
 //! Get the matrix of data points (predictors).
299  const arma::cube& Predictors() const { return predictors; }
 //! Modify the matrix of data points (predictors).
301  arma::cube& Predictors() { return predictors; }
302 
 //! Reset the state of the network. NOTE(review): the exact effect is
 //! defined in brnn_impl.hpp; confirm before relying on it.
308  void Reset();
309 
 //! Reset the module information (weights/parameters).
313  void ResetParameters();
314 
 //! Serialize the model.
316  template<typename Archive>
317  void serialize(Archive& ar, const unsigned int /* version */);
318 
319  private:
320  // Helper functions.
 //! NOTE(review): presumably propagates the `deterministic` flag to the
 //! contained layers; confirm in brnn_impl.hpp.
325  void ResetDeterministic();
326 
 //! Maximum length of backpropagation through time (see Rho()).
328  size_t rho;
329 
 //! Output layer instance supplied to the constructor.
331  OutputLayerType outputLayer;
332 
 //! Merge layer, stored as a LayerTypes variant (of layer pointers).
334  LayerTypes<CustomLayers...> mergeLayer;
335 
 //! Merge-output layer, stored as a LayerTypes variant (of layer pointers).
337  LayerTypes<CustomLayers...> mergeOutput;
338 
 //! Rule used to initialize the network parameters.
341  InitializationRuleType initializeRule;
342 
 //! The input size.
344  size_t inputSize;
345 
 //! The output size.
347  size_t outputSize;
348 
 //! The target size.
350  size_t targetSize;
351 
 //! Reset flag (presumably tracks whether the network has been reset /
 //! initialized; confirm).
353  bool reset;
354 
 //! Value of the constructor's `single` argument.
356  bool single;
357 
 //! Matrix of data points (predictors); see Predictors().
359  arma::cube predictors;
360 
 //! Matrix of responses to the input data points; see Responses().
362  arma::cube responses;
363 
 //! Model parameters; exposed via Parameters() as the optimization's initial point.
365  arma::mat parameter;
366 
 //! Number of separable functions (number of predictor points); see NumFunctions().
368  size_t numFunctions;
369 
 //! Error matrix (presumably the current backpropagated error; confirm).
371  arma::mat error;
372 
 //! Visitor that exposes the delta parameter of a module.
374  DeltaVisitor deltaVisitor;
375 
 //! Visitor that exposes the output parameter of a module.
377  OutputParameterVisitor outputParameterVisitor;
378 
 //! Stored outputs of the forward RNN (presumably one matrix per time step).
380  std::vector<arma::mat> forwardRNNOutputParameter;
381 
 //! Stored outputs of the backward RNN (presumably one matrix per time step).
383  std::vector<arma::mat> backwardRNNOutputParameter;
384 
 //! Visitor that returns the number of weights of a module.
386  WeightSizeVisitor weightSizeVisitor;
387 
 //! Visitor that executes a module's Reset() function.
389  ResetVisitor resetVisitor;
390 
 //! Visitor that executes the destructor of a module.
392  DeleteVisitor deleteVisitor;
393 
 //! Visitor that supports copy-constructing neural network modules.
395  CopyVisitor<CustomLayers...> copyVisitor;
396 
 //! Whether the network currently runs in deterministic mode.
398  bool deterministic;
399 
 //! Gradient buffer for the forward RNN.
401  arma::mat forwardGradient;
402 
 //! Gradient buffer for the backward RNN.
404  arma::mat backwardGradient;
405 
 //! Combined gradient buffer (presumably of both directions; confirm).
407  arma::mat totalGradient;
408 
 //! RNN for the forward direction.
410  RNN<OutputLayerType, InitializationRuleType, CustomLayers...> forwardRNN;
411 
 //! RNN for the backward direction.
413  RNN<OutputLayerType, InitializationRuleType, CustomLayers...> backwardRNN;
414 }; // class BRNN
415 
416 } // namespace ann
417 } // namespace mlpack
418 
420 namespace boost {
421 namespace serialization {
422 
// Declares Boost.Serialization class version 1 for every BRNN instantiation.
// The template parameters are listed in the same order, and with the same
// names, as BRNN's own template parameters. For a partial specialization the
// declaration order does not affect deduction, but matching the class keeps
// the pattern below easy to check against the class declaration.
423 template<typename OutputLayerType,
424  typename MergeLayerType,
425  typename MergeOutputType,
426  typename InitializationRuleType,
427  typename... CustomLayers>
428 struct version<
429  mlpack::ann::BRNN<OutputLayerType, MergeLayerType, MergeOutputType,
430  InitializationRuleType, CustomLayers...>>
431 {
432  BOOST_STATIC_CONSTANT(int, value = 1);
433 };
434 
435 } // namespace serialization
436 } // namespace boost
437 
438 // Include implementation.
439 #include "brnn_impl.hpp"
440 
441 #endif
DeleteVisitor executes the destructor of the instantiated object.
void ResetParameters()
Reset the module information (weights/parameters).
BaseLayer< ActivationFunction, InputDataType, OutputDataType > CustomLayer
Standard Sigmoid layer.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: brnn.hpp:281
arma::mat & Parameters()
Modify the initial point for the optimization.
Definition: brnn.hpp:286
double Train(arma::cube predictors, arma::cube responses, OptimizerType &optimizer)
Train the bidirectional recurrent neural network on the given input data using the given optimizer...
This visitor supports the copy constructor for neural network modules.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double EvaluateWithGradient(const arma::mat &parameters, const size_t begin, GradType &gradient, const size_t batchSize)
Evaluate the bidirectional recurrent neural network with the given parameters.
void Gradient(const arma::mat &parameters, const size_t begin, arma::mat &gradient, const size_t batchSize)
Evaluate the gradient of the bidirectional recurrent neural network with the given parameters...
WeightSizeVisitor returns the number of weights of the given module.
arma::cube & Responses()
Modify the matrix of responses to the input data points.
Definition: brnn.hpp:296
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, 
MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
Implementation of a standard recurrent neural network container.
Definition: rnn.hpp:45
void Reset()
Reset the state of the network.
void Predict(arma::cube predictors, arma::cube &results, const size_t batchSize=256)
Predict the responses to a given set of predictors.
ResetVisitor executes the Reset() function.
OutputParameterVisitor exposes the output parameter of the given module.
double Evaluate(const arma::mat &parameters, const size_t begin, const size_t batchSize, const bool deterministic)
Evaluate the bidirectional recurrent neural network with the given parameters.
void Shuffle()
Shuffle the order of function visitation.
const size_t & Rho() const
Return the maximum length of backpropagation through time.
Definition: brnn.hpp:289
void Add(Args...args)
Implementation of a standard bidirectional recurrent neural network container.
Definition: brnn.hpp:48
BRNN(const size_t rho, const bool single=false, OutputLayerType outputLayer=OutputLayerType(), MergeLayerType *mergeLayer=new MergeLayerType(), MergeOutputType *mergeOutput=new MergeOutputType(), InitializationRuleType initializeRule=InitializationRuleType())
Create the BRNN object.
const arma::mat & Parameters() const
Return the initial point for the optimization.
Definition: brnn.hpp:284
void serialize(Archive &ar, const unsigned int)
Serialize the model.
arma::cube & Predictors()
Modify the matrix of data points (predictors).
Definition: brnn.hpp:301
DeltaVisitor exposes the delta parameter of the given module.
const arma::cube & Responses() const
Get the matrix of responses to the input data points.
Definition: brnn.hpp:294
const arma::cube & Predictors() const
Get the matrix of data points (predictors).
Definition: brnn.hpp:299
std::enable_if< HasMaxIterations< OptimizerType, size_t &(OptimizerType::*)()>::value, void >::type WarnMessageMaxIterations(OptimizerType &optimizer, size_t samples) const
Check whether the optimizer has a MaxIterations() parameter; if it does, check whether its value is less tha...
size_t & Rho()
Modify the maximum length of backpropagation through time.
Definition: brnn.hpp:291