//===========================================================================
/*!
*
*
* \brief super class of all loss functions
*
*
*
* \author T. Glasmachers
* \date 2010-2011
*
*
* \par Copyright 1995-2017 Shark Development Team
*
*
* This file is part of Shark.
*
*
* Shark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Shark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
 * along with Shark. If not, see <https://www.gnu.org/licenses/>.
*
*/
#ifndef SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H
#define SHARK_OBJECTIVEFUNCTIONS_LOSS_ABSTRACTLOSS_H
#include <shark/ObjectiveFunctions/AbstractCost.h>
#include <shark/LinAlg/Base.h>
#include <shark/Core/Traits/ProxyReferenceTraits.h>
namespace shark {
/// \brief Loss function interface
///
/// \par
/// In statistics and machine learning, a loss function encodes
/// the severity of getting a label wrong. This is an important
/// special case of a cost function (see AbstractCost), where
/// the cost is computed as the average loss over a set, also
/// known as (empirical) risk.
///
/// \par
/// It is generally agreed that loss values are non-negative,
/// and that the loss of correct prediction is zero. This rule
/// is not formally checked, but instead left to the various
/// sub-classes.
///
template
class AbstractLoss : public AbstractCost
{
public:
typedef AbstractCost base_type;
typedef OutputT OutputType;
typedef LabelT LabelType;
typedef RealMatrix MatrixType;
typedef typename Batch::type BatchOutputType;
typedef typename Batch::type BatchLabelType;
/// \brief Const references to LabelType
typedef typename ConstProxyReference::type ConstLabelReference;
/// \brief Const references to OutputType
typedef typename ConstProxyReference::type ConstOutputReference;
AbstractLoss(){
this->m_features |= base_type::IS_LOSS_FUNCTION;
}
/// \brief evaluate the loss for a batch of targets and a prediction
///
/// \param target target values
/// \param prediction predictions, typically made by a model
virtual double eval( BatchLabelType const& target, BatchOutputType const& prediction) const = 0;
/// \brief evaluate the loss for a target and a prediction
///
/// \param target target value
/// \param prediction prediction, typically made by a model
virtual double eval( ConstLabelReference target, ConstOutputReference prediction)const{
BatchLabelType labelBatch = Batch::createBatch(target,1);
getBatchElement(labelBatch,0)=target;
BatchOutputType predictionBatch = Batch::createBatch(prediction,1);
getBatchElement(predictionBatch,0)=prediction;
return eval(labelBatch,predictionBatch);
}
/// \brief evaluate the loss and its derivative for a target and a prediction
///
/// \param target target value
/// \param prediction prediction, typically made by a model
/// \param gradient the gradient of the loss function with respect to the prediction
virtual double evalDerivative(ConstLabelReference target, ConstOutputReference prediction, OutputType& gradient) const {
BatchLabelType labelBatch = Batch::createBatch(target,1);
getBatchElement(labelBatch, 0) = target;
BatchOutputType predictionBatch = Batch::createBatch(prediction, 1);
getBatchElement(predictionBatch, 0) = prediction;
BatchOutputType gradientBatch = Batch::createBatch(gradient, 1);
double ret = evalDerivative(labelBatch, predictionBatch, gradientBatch);
gradient = getBatchElement(gradientBatch, 0);
return ret;
}
/// \brief evaluate the loss and its first and second derivative for a target and a prediction
///
/// \param target target value
/// \param prediction prediction, typically made by a model
/// \param gradient the gradient of the loss function with respect to the prediction
/// \param hessian the hessian of the loss function with respect to the prediction
virtual double evalDerivative(
ConstLabelReference target, ConstOutputReference prediction,
OutputType& gradient,MatrixType & hessian
) const {
SHARK_FEATURE_EXCEPTION_DERIVED(HAS_SECOND_DERIVATIVE);
return 0.0; // dead code, prevent warning
}
/// \brief evaluate the loss and the derivative w.r.t. the prediction
///
/// \par
/// The default implementations throws an exception.
/// If you overwrite this method, don't forget to set
/// the flag HAS_FIRST_DERIVATIVE.
/// \param target target value
/// \param prediction prediction, typically made by a model
/// \param gradient the gradient of the loss function with respect to the prediction
virtual double evalDerivative(BatchLabelType const& target, BatchOutputType const& prediction, BatchOutputType& gradient) const
{
SHARK_FEATURE_EXCEPTION_DERIVED(HAS_FIRST_DERIVATIVE);
return 0.0; // dead code, prevent warning
}
//~ /// \brief evaluate the loss and fist and second derivative w.r.t. the prediction
//~ ///
//~ /// \par
//~ /// The default implementations throws an exception.
//~ /// If you overwrite this method, don't forget to set
//~ /// the flag HAS_FIRST_DERIVATIVE.
//~ /// \param target target value
//~ /// \param prediction prediction, typically made by a model
//~ /// \param gradient the gradient of the loss function with respect to the prediction
//~ /// \param hessian the hessian matrix of the loss function with respect to the prediction
//~ virtual double evalDerivative(
//~ LabelType const& target,
//~ OutputType const& prediction,
//~ OutputType& gradient,
//~ MatrixType& hessian) const
//~ {
//~ SHARK_FEATURE_EXCEPTION_DERIVED(HAS_SECOND_DERIVATIVE);
//~ return 0.0; // dead code, prevent warning
//~ }
/// from AbstractCost
///
/// \param targets target values
/// \param predictions predictions, typically made by a model
double eval(Data const& targets, Data const& predictions) const{
SIZE_CHECK(predictions.numberOfElements() == targets.numberOfElements());
SIZE_CHECK(predictions.numberOfBatches() == targets.numberOfBatches());
int numBatches = (int) targets.numberOfBatches();
double error = 0;
SHARK_PARALLEL_FOR(int i = 0; i < numBatches; ++i){
double batchError= eval(targets.batch(i),predictions.batch(i));
SHARK_CRITICAL_REGION{
error+=batchError;
}
}
return error / targets.numberOfElements();
}
/// \brief evaluate the loss for a target and a prediction
///
/// \par
/// convenience operator
///
/// \param target target value
/// \param prediction prediction, typically made by a model
double operator () (LabelType const& target, OutputType const& prediction) const
{ return eval(target, prediction); }
double operator () (BatchLabelType const& target, BatchOutputType const& prediction) const
{ return eval(target, prediction); }
using base_type::operator();
};
}
#endif