mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
hardshrink.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_ANN_LAYER_HARDSHRINK_HPP
17 #define MLPACK_METHODS_ANN_LAYER_HARDSHRINK_HPP
18 
19 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
44 template <
45  typename InputDataType = arma::mat,
46  typename OutputDataType = arma::mat
47 >
49 {
50  public:
59  HardShrink(const double lambda = 0.5);
60 
68  template<typename InputType, typename OutputType>
69  void Forward(const InputType& input, OutputType& output);
70 
80  template<typename DataType>
81  void Backward(const DataType& input,
82  DataType& gy,
83  DataType& g);
84 
86  OutputDataType const& OutputParameter() const { return outputParameter; }
88  OutputDataType& OutputParameter() { return outputParameter; }
89 
91  OutputDataType const& Delta() const { return delta; }
93  OutputDataType& Delta() { return delta; }
94 
96  double const& Lambda() const { return lambda; }
98  double& Lambda() { return lambda; }
99 
103  template<typename Archive>
104  void serialize(Archive& ar, const unsigned int /* version */);
105 
106  private:
108  OutputDataType delta;
109 
111  OutputDataType outputParameter;
112 
114  double lambda;
115 }; // class HardShrink
116 
117 } // namespace ann
118 } // namespace mlpack
119 
120 // Include implementation.
121 #include "hardshrink_impl.hpp"
122 
123 #endif
OutputDataType const & Delta() const
Get the delta.
Definition: hardshrink.hpp:91
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: hardshrink.hpp:88
The core includes that mlpack expects; standard C++ includes and Armadillo.
double const & Lambda() const
Get the hyperparameter lambda.
Definition: hardshrink.hpp:96
HardShrink(const double lambda=0.5)
Create HardShrink object using specified hyperparameter lambda.
double & Lambda()
Modify the hyperparameter lambda.
Definition: hardshrink.hpp:98
OutputDataType & Delta()
Modify the delta.
Definition: hardshrink.hpp:93
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: hardshrink.hpp:86
Hard Shrink operator is defined as \(f(x) = x\) if \(|x| > \lambda\) and \(f(x) = 0\) otherwise; lambda is set to 0.5 by default.
Definition: hardshrink.hpp:48
void Backward(const DataType &input, DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.