//===========================================================================
/*!
*
*
* \brief cross-validation error for selection of hyper-parameters
*
*
* \author T. Glasmachers, O. Krause
* \date 2007-2012
*
*
* \par Copyright 1995-2017 Shark Development Team
*
*
* This file is part of Shark.
*
*
* Shark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Shark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with Shark. If not, see <http://www.gnu.org/licenses/>.
*
*/
//===========================================================================
#ifndef SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
#define SHARK_OBJECTIVEFUNCTIONS_CROSSVALIDATIONERROR_H
#include
#include
#include
#include
#include
namespace shark {
///
/// \brief Cross-validation error for selection of hyper-parameters.
///
/// \par
/// The cross-validation error is useful for evaluating
/// how well a model performs on a problem. It is regularly
/// used for model selection.
///
/// \par
/// In Shark, the cross-validation procedure is abstracted
/// as follows:
/// First, the given point is written into an IParameterizable
/// object (such as a regularizer or a trainer). Then a model
/// is trained with a trainer with the given settings on a
/// number of folds and evaluated on the corresponding validation
/// sets with a cost function. The average cost function value
/// over all folds is returned.
///
/// \par
/// Thus, the cross-validation procedure requires a "meta"
/// IParameterizable object, a model, a trainer, a data set,
/// and a cost function.
///
/// \par
/// The meta object, the model, the trainer and the cost function are
/// stored as plain pointers and are never deleted by this class; they
/// have to stay alive for as long as the objective function is used.
/// The folds are stored as a copy of the object passed to the constructor.
///
/// \tparam ModelTypeT  type of the model that is trained on every fold;
///                     has to provide the nested types InputType and OutputType
/// \tparam LabelTypeT  type of the labels in the data set; defaults to the
///                     output type of the model
///
template<class ModelTypeT, class LabelTypeT = typename ModelTypeT::OutputType>
class CrossValidationError : public AbstractObjectiveFunction< RealVector, double >
{
public:
	typedef typename ModelTypeT::InputType InputType;
	typedef typename ModelTypeT::OutputType OutputType;
	typedef LabelTypeT LabelType;
	typedef LabeledData<InputType, LabelType> DatasetType;
	typedef CVFolds<DatasetType> FoldsType;
	typedef ModelTypeT ModelType;
	typedef AbstractTrainer<ModelType, LabelType> TrainerType;
	typedef AbstractCost<LabelType, OutputType> CostType;

private:
	typedef SingleObjectiveFunction base_type;

	FoldsType m_folds;              ///< training/validation partitions (own copy)
	IParameterizable<>* mep_meta;   ///< receives the search point in eval(); not owned
	ModelType* mep_model;           ///< model that is re-trained on every fold; not owned
	TrainerType* mep_trainer;       ///< trainer applied to the model; not owned
	CostType* mep_cost;             ///< cost used on the validation predictions; not owned

public:
	/// \brief Set up the cross-validation objective.
	///
	/// \param dataFolds  folds to cross-validate on (copied)
	/// \param meta       object whose parameter vector is the search point
	/// \param model      model to be trained on each fold
	/// \param trainer    trainer used to fit the model
	/// \param cost       cost function evaluated on each validation set
	///
	/// None of the pointers is checked for null; all of them are
	/// dereferenced by the other member functions.
	CrossValidationError(
		FoldsType const& dataFolds,
		IParameterizable<>* meta,
		ModelType* model,
		TrainerType* trainer,
		CostType* cost)
	: m_folds(dataFolds)
	, mep_meta(meta)
	, mep_model(model)
	, mep_trainer(trainer)
	, mep_cost(cost)
	{ }

	/// \brief From INameable: return the class name.
	///
	/// The name is assembled from the names of the model, the trainer
	/// and the cost function.
	std::string name() const
	{
		return "CrossValidationError<"
			+ mep_model->name() + ","
			+ mep_trainer->name() + ","
			+ mep_cost->name() + ">";
	}

	/// \brief Dimension of the search space: the number of parameters
	/// of the meta object.
	std::size_t numberOfVariables()const{
		return mep_meta->numberOfParameters();
	}

	/// \brief Evaluate the cross-validation error.
	///
	/// Writes \p parameters into the meta object, then for every fold
	/// trains the model on the training part, predicts the validation
	/// inputs and adds up the cost of the predictions against the
	/// validation labels.
	///
	/// \param parameters  point to evaluate, passed to the meta object
	/// \return sum of the per-fold costs divided by the number of folds
	///
	/// Although the function is const, it changes the objects behind the
	/// stored pointers: the meta parameters are overwritten and the model
	/// is re-trained. The number of folds is not checked before the
	/// division, so the folds are expected to be non-empty.
	double eval(RealVector const& parameters) const {
		this->m_evaluationCounter++;
		mep_meta->setParameterVector(parameters);

		double ret = 0.0;
		for (std::size_t setID = 0; setID != m_folds.size(); ++setID) {
			DatasetType train = m_folds.training(setID);
			DatasetType validation = m_folds.validation(setID);
			// the model is overwritten by every fold; only the cost is kept
			mep_trainer->train(*mep_model, train);
			Data<OutputType> output = (*mep_model)(validation.inputs());
			ret += mep_cost->eval(validation.labels(), output);
		}
		return ret / m_folds.size();
	}
};
}
#endif