/*!
*
*
* \brief error function for supervised learning
*
*
*
 * \author T. Voss, T. Glasmachers, O. Krause
* \date 2010-2011
*
*
* \par Copyright 1995-2017 Shark Development Team
*
*
 * This file is part of Shark.
 * <https://shark-ml.github.io/Shark/>
*
* Shark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Shark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <http://www.gnu.org/licenses/>.
*
*/
#ifndef SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
#define SHARK_OBJECTIVEFUNCTIONS_ERRORFUNCTION_H
#include <shark/Models/AbstractModel.h>
#include <shark/ObjectiveFunctions/Loss/AbstractLoss.h>
#include <shark/ObjectiveFunctions/AbstractObjectiveFunction.h>
#include <shark/Data/Dataset.h>
#include <shark/Data/WeightedDataset.h>
#include "Impl/ErrorFunction.inl"

#include <boost/scoped_ptr.hpp>
namespace shark{
///
/// \brief Objective function for supervised learning
///
/// \par
/// An ErrorFunction object is an objective function for
/// learning the parameters of a model from data by means
/// of minimization of a cost function. The value of the
/// objective function is the cost of the model predictions
/// on the training data, given the targets.
/// \par
/// It supports mini-batch learning via an optional fourth argument to
/// the constructor. With mini-batch learning enabled, a random batch is
/// drawn from the dataset in each iteration; the mini-batch size is thus
/// the batch size of the dataset. Normalization ensures that batches of
/// different sizes have approximately the same magnitude of error and
/// derivative.
///
/// \par
/// It automatically infers the input and label types from the given dataset
/// and the output type of the model in the constructor, and ensures that
/// model and loss match. Thus the user does not need to provide the types
/// as template parameters.
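///
/// \par Example
/// A minimal usage sketch (assuming a regression setup; the dataset
/// variable is illustrative, while LinearModel, SquaredLoss and CG are
/// ordinary Shark components):
/// \code
/// LabeledData<RealVector, RealVector> data = ...; // training data
/// LinearModel<> model(inputDimension(data), labelDimension(data));
/// SquaredLoss<> loss;
/// ErrorFunction<> error(data, &model, &loss);
/// error.init();
/// CG<> optimizer;
/// optimizer.init(error);
/// for(std::size_t i = 0; i != 100; ++i)
///     optimizer.step(error);
/// model.setParameterVector(optimizer.solution().point);
/// \endcode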
template<class SearchPointType = RealVector>
class ErrorFunction : public AbstractObjectiveFunction<SearchPointType, double>
{
private:
    typedef AbstractObjectiveFunction<SearchPointType, double> FunctionType;
public:
    typedef typename FunctionType::ResultType ResultType;
    typedef typename FunctionType::FirstOrderDerivative FirstOrderDerivative;

    template<class InputType, class LabelType, class OutputType>
    ErrorFunction(
        LabeledData<InputType, LabelType> const& dataset,
        AbstractModel<InputType, OutputType, SearchPointType>* model,
        AbstractLoss<LabelType, OutputType>* loss,
        bool useMiniBatches = false
    ){
        m_regularizer = nullptr;
        mp_wrapper.reset(new detail::ErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss, useMiniBatches));
        this->m_features = mp_wrapper->features();
    }
    template<class InputType, class LabelType, class OutputType>
    ErrorFunction(
        WeightedLabeledData<InputType, LabelType> const& dataset,
        AbstractModel<InputType, OutputType, SearchPointType>* model,
        AbstractLoss<LabelType, OutputType>* loss
    ){
        m_regularizer = nullptr;
        mp_wrapper.reset(new detail::WeightedErrorFunctionImpl<InputType, LabelType, OutputType, SearchPointType>(dataset, model, loss));
        this->m_features = mp_wrapper->features();
    }
    ErrorFunction(ErrorFunction const& op)
    : mp_wrapper(op.mp_wrapper->clone()){
        this->m_features = mp_wrapper->features();
    }
    ErrorFunction& operator=(ErrorFunction const& op){
        ErrorFunction copy(op);
        swap(copy.mp_wrapper, mp_wrapper);
        swap(copy.m_features, this->m_features);
        return *this;
    }

    std::string name() const
    { return "ErrorFunction"; }
    void setRegularizer(double factor, FunctionType* regularizer){
        m_regularizer = regularizer;
        m_regularizationStrength = factor;
    }
    SearchPointType proposeStartingPoint() const{
        return mp_wrapper->proposeStartingPoint();
    }
    std::size_t numberOfVariables() const{
        return mp_wrapper->numberOfVariables();
    }
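
    /// \brief Initializes the function; it also hands the objective's random
    /// number generator to the wrapper, which uses it to sample mini-batches.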
    void init(){
        mp_wrapper->setRng(this->mep_rng);
        mp_wrapper->init();
    }
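
    /// \brief Evaluates the error of the model on the data, plus the
    /// weighted regularizer if one is set.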
    double eval(SearchPointType const& input) const{
        ++this->m_evaluationCounter;
        double value = mp_wrapper->eval(input);
        if(m_regularizer)
            value += m_regularizationStrength * m_regularizer->eval(input);
        return value;
    }
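
    /// \brief Evaluates the error and its derivative with respect to the
    /// parameters; a set regularizer contributes both its weighted value and
    /// its weighted gradient.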
    ResultType evalDerivative(SearchPointType const& input, FirstOrderDerivative& derivative) const{
        ++this->m_evaluationCounter;
        double value = mp_wrapper->evalDerivative(input, derivative);
        if(m_regularizer){
            FirstOrderDerivative regularizerDerivative;
            value += m_regularizationStrength * m_regularizer->evalDerivative(input, regularizerDerivative);
            noalias(derivative) += m_regularizationStrength * regularizerDerivative;
        }
        return value;
    }
private:
    boost::scoped_ptr<detail::FunctionWrapperBase<SearchPointType> > mp_wrapper;
    FunctionType* m_regularizer;
    double m_regularizationStrength;
};
}
#endif