mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
dice_loss.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_DICE_LOSS_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
44 template <
45  typename InputDataType = arma::mat,
46  typename OutputDataType = arma::mat
47 >
48 class DiceLoss
49 {
50  public:
56  DiceLoss(const double smooth = 1);
57 
64  template<typename InputType, typename TargetType>
65  typename InputType::elem_type Forward(const InputType& input,
66  const TargetType& target);
67 
75  template<typename InputType, typename TargetType, typename OutputType>
76  void Backward(const InputType& input,
77  const TargetType& target,
78  OutputType& output);
79 
81  OutputDataType& OutputParameter() const { return outputParameter; }
83  OutputDataType& OutputParameter() { return outputParameter; }
84 
86  double Smooth() const { return smooth; }
88  double& Smooth() { return smooth; }
89 
93  template<typename Archive>
94  void serialize(Archive& ar, const unsigned int /* version */);
95 
96  private:
98  OutputDataType outputParameter;
99 
101  double smooth;
102 }; // class DiceLoss
103 
104 } // namespace ann
105 } // namespace mlpack
106 
107 // Include implementation.
108 #include "dice_loss_impl.hpp"
109 
110 #endif
The dice loss performance function measures the network's performance according to the dice coefficient...
Definition: dice_loss.hpp:48
OutputDataType & OutputParameter() const
Get the output parameter.
Definition: dice_loss.hpp:81
double Smooth() const
Get the smooth parameter.
Definition: dice_loss.hpp:86
The core includes that mlpack expects; standard C++ includes and Armadillo.
double & Smooth()
Modify the smooth parameter.
Definition: dice_loss.hpp:88
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
DiceLoss(const double smooth=1)
Create the DiceLoss object.
InputType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the dice loss function.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: dice_loss.hpp:83