mlpack  3.1.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
recurrent.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
14 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
15 
16 #include <mlpack/core.hpp>
17 
18 #include "../visitor/delete_visitor.hpp"
19 #include "../visitor/delta_visitor.hpp"
20 #include "../visitor/copy_visitor.hpp"
21 #include "../visitor/output_parameter_visitor.hpp"
22 
23 #include "layer_types.hpp"
24 #include "add_merge.hpp"
25 #include "sequential.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
39 template <
40  typename InputDataType = arma::mat,
41  typename OutputDataType = arma::mat,
42  typename... CustomLayers
43 >
44 class Recurrent
45 {
46  public:
51  Recurrent();
52 
54  Recurrent(const Recurrent&);
55 
65  template<typename StartModuleType,
66  typename InputModuleType,
67  typename FeedbackModuleType,
68  typename TransferModuleType>
69  Recurrent(const StartModuleType& start,
70  const InputModuleType& input,
71  const FeedbackModuleType& feedback,
72  const TransferModuleType& transfer,
73  const size_t rho);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& /* input */,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
99  /*
100  * Calculate the gradient using the output delta and the input activation.
101  *
102  * @param input The input parameter used for calculating the gradient.
103  * @param error The calculated error.
104  * @param gradient The calculated gradient.
105  */
106  template<typename eT>
107  void Gradient(const arma::Mat<eT>& input,
108  const arma::Mat<eT>& error,
109  arma::Mat<eT>& /* gradient */);
110 
112  std::vector<LayerTypes<CustomLayers...> >& Model() { return network; }
113 
115  bool Deterministic() const { return deterministic; }
117  bool& Deterministic() { return deterministic; }
118 
120  OutputDataType const& Parameters() const { return parameters; }
122  OutputDataType& Parameters() { return parameters; }
123 
125  OutputDataType const& OutputParameter() const { return outputParameter; }
127  OutputDataType& OutputParameter() { return outputParameter; }
128 
130  OutputDataType const& Delta() const { return delta; }
132  OutputDataType& Delta() { return delta; }
133 
135  OutputDataType const& Gradient() const { return gradient; }
137  OutputDataType& Gradient() { return gradient; }
138 
142  template<typename Archive>
143  void serialize(Archive& ar, const unsigned int /* version */);
144 
145  private:
147  DeleteVisitor deleteVisitor;
148 
150  CopyVisitor<CustomLayers...> copyVisitor;
151 
153  LayerTypes<CustomLayers...> startModule;
154 
156  LayerTypes<CustomLayers...> inputModule;
157 
159  LayerTypes<CustomLayers...> feedbackModule;
160 
162  LayerTypes<CustomLayers...> transferModule;
163 
165  size_t rho;
166 
168  size_t forwardStep;
169 
171  size_t backwardStep;
172 
174  size_t gradientStep;
175 
177  bool deterministic;
178 
181  bool ownsLayer;
182 
184  OutputDataType parameters;
185 
187  LayerTypes<CustomLayers...> initialModule;
188 
190  LayerTypes<CustomLayers...> recurrentModule;
191 
193  std::vector<LayerTypes<CustomLayers...> > network;
194 
196  LayerTypes<CustomLayers...> mergeModule;
197 
199  DeltaVisitor deltaVisitor;
200 
202  OutputParameterVisitor outputParameterVisitor;
203 
205  std::vector<arma::mat> feedbackOutputParameter;
206 
208  OutputDataType delta;
209 
211  OutputDataType gradient;
212 
214  OutputDataType outputParameter;
215 
217  arma::mat recurrentError;
218 }; // class Recurrent
219 
220 } // namespace ann
221 } // namespace mlpack
222 
223 // Include implementation.
224 #include "recurrent_impl.hpp"
225 
226 #endif
DeleteVisitor executes the destructor of the instantiated object.
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent.hpp:130
std::vector< LayerTypes< CustomLayers...> > & Model()
Get the model modules.
Definition: recurrent.hpp:112
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent.hpp:117
This visitor supports the copy constructor for neural network modules.
OutputDataType & Delta()
Modify the delta.
Definition: recurrent.hpp:132
boost::variant< Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, 
MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent.hpp:125
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent.hpp:137
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
OutputParameterVisitor exposes the output parameter of the given module.
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent.hpp:122
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure ...
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
DeltaVisitor exposes the delta parameter of the given module.
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent.hpp:120
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent.hpp:115
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent.hpp:127