13 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP 14 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP 19 #include <boost/variant.hpp> 37 const arma::mat& delta,
41 template<
typename LayerType>
48 const arma::mat& input;
51 const arma::mat& delta;
62 typename std::enable_if<
63 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
64 !HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
65 LayerGradients(T* layer, arma::mat& input)
const;
70 typename std::enable_if<
71 HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72 HasRunCheck<T, bool&(T::*)(void)>::value,
void>::type
73 LayerGradients(T* layer, arma::mat& input)
const;
// Fallback overload of LayerGradients(): it participates in overload
// resolution only when HasGradientCheck<T, P&(T::*)()>::value is false
// (std::enable_if removes it otherwise). This is the case where the check
// does not detect a Gradient()-style accessor on T returning P&.
// NOTE(review): presumably a no-op for modules without Gradient(); confirm
// against the definition in gradient_visitor_impl.hpp.
// @param layer  Module being visited (non-owning pointer).
// @param input  Input passed through to the module; its type P is deduced.
77 template<
typename T,
typename P>
78 typename std::enable_if<
79 !HasGradientCheck<T, P&(T::*)()>::value,
void>::type
80 LayerGradients(T* layer, P& input)
const;
87 #include "gradient_visitor_impl.hpp" boost::variant< Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *> MoreTypes
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameters.
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
void operator()(LayerType *layer) const
Executes the Gradient() method.