/*!
*
*
* \brief Implements a dropout layer.
*
* \author O.Krause
* \date 2017
*
*
* \par Copyright 1995-2017 Shark Development Team
*
*
* This file is part of Shark.
*
*
* Shark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Shark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
*
*/
#ifndef MODELS_DROPOUTLAYER_H
#define MODELS_DROPOUTLAYER_H
#include <shark/Core/Random.h>
#include <shark/Models/AbstractModel.h>
#include <shark/LinAlg/Base.h>
namespace shark{
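///\brief Dropout layer: randomly sets a fraction of its inputs to zero.
///
/// Each input value is kept if a coin toss with the given probability succeeds
/// and is set to zero otherwise. When evaluated with a State object, the sampled
/// 0/1 mask is stored so that the derivative is propagated only through the kept
/// units. The layer has no trainable parameters.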
template <class VectorType = RealVector>
class DropoutLayer : public AbstractModel<VectorType, VectorType, VectorType>{
private:
typedef AbstractModel<VectorType, VectorType, VectorType> base_type;
typedef blas::matrix<typename VectorType::value_type, blas::row_major, typename VectorType::device_type> MatrixType;
struct InternalState: public State{
MatrixType mask;
};
Shape m_shape;
random::rng_type* mep_rng;
double m_dropoutProbability;
public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
typedef typename base_type::ParameterVectorType ParameterVectorType;
DropoutLayer(Shape const& inputShape, double probability = 0.5, random::rng_type& rng = random::globalRng)
: m_shape(inputShape), mep_rng(&rng), m_dropoutProbability(probability){
base_type::m_features |= base_type::HAS_FIRST_PARAMETER_DERIVATIVE;
base_type::m_features |= base_type::HAS_FIRST_INPUT_DERIVATIVE;
}
/// \brief From INameable: return the class name.
std::string name() const
{ return "DropoutLayer"; }
/// obtain the parameter vector
ParameterVectorType parameterVector() const{
return ParameterVectorType();
}
/// overwrite the parameter vector
void setParameterVector(ParameterVectorType const& newParameters){
SIZE_CHECK(newParameters.size() == 0);
}
/// return the number of parameters
size_t numberOfParameters() const{
return 0;
}
///\brief Returns the expected shape of the input
Shape inputShape() const{
return m_shape;
}
///\brief Returns the shape of the output
Shape outputShape() const{
return m_shape;
}
boost::shared_ptr<State> createState()const{
return boost::shared_ptr<State>(new InternalState());
}
using base_type::eval;
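///\brief Evaluates a batch without storing the mask: each element is kept only if its coin toss succeeds, otherwise it is set to zero.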
void eval(BatchInputType const& inputs, BatchOutputType& outputs)const{
outputs.resize(inputs.size1(),inputs.size2());
noalias(outputs) = inputs;
for(std::size_t i = 0; i != outputs.size1(); ++i){
for(std::size_t j = 0; j != outputs.size2(); ++j){
if(!random::coinToss(*mep_rng,m_dropoutProbability)){
outputs(i,j) = 0;
}
}
}
}
void eval(VectorType const& input, VectorType& output)const {
output.resize(input.size());
noalias(output) = input;
for(std::size_t j = 0; j != output.size(); ++j){
if(!random::coinToss(*mep_rng,m_dropoutProbability)){
output(j) = 0;
}
}
}
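///\brief Evaluates a batch and stores the sampled 0/1 mask in the state; operator* between matrices is element-wise in Shark's blas.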
void eval(BatchInputType const& inputs, BatchOutputType& outputs, State& state)const{
MatrixType& mask = state.toState<InternalState>().mask;
outputs.resize(inputs.size1(),inputs.size2());
mask.resize(inputs.size1(),inputs.size2());
for(std::size_t i = 0; i != outputs.size1(); ++i){
for(std::size_t j = 0; j != outputs.size2(); ++j){
mask(i,j) = random::coinToss(*mep_rng,m_dropoutProbability);
}
}
noalias(outputs) = inputs * mask;
}
///\brief Calculates the first derivative w.r.t. the parameters, summed over all patterns of the last computed batch
void weightedParameterDerivative(
BatchInputType const& patterns,
BatchOutputType const& outputs,
BatchOutputType const& coefficients,
State const& state,
ParameterVectorType& gradient
)const{
SIZE_CHECK(coefficients.size1()==patterns.size1());
SIZE_CHECK(coefficients.size2()==patterns.size2());
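// dropout has no parameters, so there is nothing to accumulate into the gradient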
}
///\brief Calculates the first derivative w.r.t. the inputs, summed over all patterns of the last computed batch
void weightedInputDerivative(
BatchInputType const & patterns,
BatchOutputType const & outputs,
BatchOutputType const & coefficients,
State const& state,
BatchInputType& derivative
)const{
SIZE_CHECK(coefficients.size1() == patterns.size1());
SIZE_CHECK(coefficients.size2() == patterns.size2());
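// reuse the mask sampled during eval: derivatives pass only through the kept units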
MatrixType const& mask = state.toState<InternalState>().mask;
derivative.resize(coefficients.size1(),coefficients.size2());
noalias(derivative) = coefficients * mask;
}
/// From ISerializable
void read(InArchive& archive){archive >> m_dropoutProbability;}
/// From ISerializable
void write(OutArchive& archive) const{ archive << m_dropoutProbability;}
};
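// A minimal usage sketch, assuming RealVector inputs; the shape, batch size and
// probability below are arbitrary example values (a unit is kept when the coin
// toss with this probability succeeds):
//
//	DropoutLayer<RealVector> dropout(Shape({10}), 0.5);
//	RealMatrix inputs(32, 10, 1.0);
//	RealMatrix outputs;
//	boost::shared_ptr<State> state = dropout.createState();
//	dropout.eval(inputs, outputs, *state); // samples a mask and stores it in the state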
}
#endif