//===========================================================================
/*!
*
*
* \brief Cart Classifier
*
*
*
* \author K. N. Hansen, J. Kremer
* \date 2012
*
*
* \par Copyright 1995-2017 Shark Development Team
*
*
* This file is part of Shark.
*
*
* Shark is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Shark is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with Shark. If not, see .
*
*/
//===========================================================================
#ifndef SHARK_MODELS_TREES_CARTree_H
#define SHARK_MODELS_TREES_CARTree_H
#include
#include
namespace shark {
///
/// \brief Classification and Regression Tree.
///
/// \par
/// The CARTree predicts a class label using a decision tree
template
class CARTree : public AbstractModel
{
private:
typedef AbstractModel base_type;
public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
struct Node{
std::size_t attributeIndex;
double attributeValue;
std::size_t leftId;
std::size_t rightIdOrIndex;
template
void serialize(Archive & ar, const unsigned int version){
ar & attributeIndex;
ar & attributeValue;
ar & leftId;
ar & rightIdOrIndex;///< either id of right node or index to label array
}
};
typedef std::vector TreeType;
/// Constructor
CARTree(): m_inputDimension(0){}
CARTree(std::size_t inputDimension, Shape const& outputShape)
: m_inputDimension(inputDimension)
, m_outputShape(outputShape){}
/// \brief From INameable: return the class name.
std::string name() const
{ return "CARTree"; }
boost::shared_ptr createState() const{
return boost::shared_ptr(new EmptyState());
}
using base_type::eval;
/// \brief Evaluate the Tree on a batch of patterns
void eval(BatchInputType const& patterns, BatchOutputType & outputs) const{
std::size_t numPatterns = patterns.size1();
//evaluate the first pattern alone and create the batch output from that
LabelType const& firstResult = evalPattern(row(patterns,0));
outputs = Batch::createBatch(firstResult,numPatterns);
getBatchElement(outputs,0) = firstResult;
//evaluate the rest
for(std::size_t i = 0; i != numPatterns; ++i){
getBatchElement(outputs,i) = evalPattern(row(patterns,i));
}
}
void eval(BatchInputType const& patterns, BatchOutputType & outputs, State& state) const{
eval(patterns,outputs);
}
/// \brief Evaluate the Tree on a single pattern
void eval(RealVector const& pattern, LabelType& output){
output = evalPattern(pattern);
}
/// \brief The model does not have any parameters.
std::size_t numberOfParameters() const{
return 0;
}
/// \brief The model does not have any parameters.
RealVector parameterVector() const {
return RealVector();
}
/// \brief The model does not have any parameters.
void setParameterVector(RealVector const& param) {
SHARK_ASSERT(param.size() == 0);
}
/// from ISerializable, reads a model from an archive
void read(InArchive& archive){
archive >> m_tree;
archive >> m_labels;
archive >> m_inputDimension;
archive >> m_outputShape;
}
/// from ISerializable, writes a model to an archive
void write(OutArchive& archive) const {
archive << m_tree;
archive << m_labels;
archive << m_inputDimension;
archive << m_outputShape;
}
//Count how often attributes are used
UIntVector countAttributes() const {
SHARK_ASSERT(m_inputDimension > 0);
UIntVector r(m_inputDimension, 0);
for(auto it = m_tree.begin(); it != m_tree.end(); ++it) {
if(it->leftId != 0) { // not a label
r(it->attributeIndex)++;
}
}
return r;
}
///Return input dimension
Shape inputShape() const {
return m_inputDimension;
}
Shape outputShape() const{
return m_outputShape;
}
////////////////////////////////
/////Tree Construction routines
///////////////////////////////
std::size_t numberOfNodes() const{
return m_tree.size();
}
/// \brief Returns the node with id nodeId
Node& getNode(std::size_t nodeId){
SIZE_CHECK(nodeId < m_tree.size());
return m_tree[nodeId];
}
/// \brief Returns the node with id nodeId
Node const& getNode(std::size_t nodeId)const{
SIZE_CHECK(nodeId < m_tree.size());
return m_tree[nodeId];
}
LabelType const& getLabel(std::size_t nodeId)const{
SIZE_CHECK(nodeId < m_tree.size());
return m_labels[m_tree[nodeId].rightIdOrIndex];
}
/// \brief Creates and returns an untyped root node (neither internal, nor leaf node)
Node& createRoot(){
m_tree.clear();
Node root;
root.leftId = 0;
root.rightIdOrIndex = 0;
m_tree.push_back(root);
return m_tree[0];
}
///\brief Transforms an untyped node (no child, no internal node) into an internal node
///
/// This creates already the two childs of the node, which are untyped.
Node& transformInternalNode(std::size_t nodeId, std::size_t attributeIndex, double attributeValue) {
// ids for new child nodes
int nodeIdLeft = m_tree.size();
int nodeIdRight = m_tree.size() + 1;
//create new child nodes
Node leftChild;
leftChild.leftId = 0;
leftChild.rightIdOrIndex = 0;
Node rightChild;
rightChild.leftId = 0;
rightChild.rightIdOrIndex = 0;
m_tree.push_back(leftChild);
m_tree.push_back(rightChild);
// connect the parent node with its two childs
m_tree[nodeId].leftId = nodeIdLeft;
m_tree[nodeId].rightIdOrIndex = nodeIdRight;
m_tree[nodeId].attributeIndex = attributeIndex;
m_tree[nodeId].attributeValue = attributeValue;
return m_tree[nodeId];
}
///\brief Transforms a node (no leaf) into a leaf node and inserts the appropriate label
///
/// If the node was an internal node before, its connections get removed and the childs
/// are not reachable any more. Calling a reorder routine like reorderBFS() will get rid of those
/// nodes.
Node& transformLeafNode(std::size_t nodeId, LabelType const& label){
Node& node = m_tree[nodeId];
node.attributeIndex = 0;
node. attributeValue = 0.0;
node.leftId = 0;
node.rightIdOrIndex = m_labels.size();
m_labels.push_back(label);
return node;
}
/// \brief Reorders a tree into a breath-first-ordering
///
/// This function call will remove all unreachable subtrees while reordering
/// the nodes by their depth in the tree, i.e. first comes the root, the the children
/// of the root, than their children, etc.
void reorderBFS(){
TreeType reordered_tree;
reordered_tree.reserve(m_tree.size());
std::deque bfs_queue;
bfs_queue.push_back(0);
std::size_t nodeId = 0; //running id of the next node to insert
while(!bfs_queue.empty()){
Node const& node = getNode(bfs_queue.front());
bfs_queue.pop_front();
//check leaf
if(!node.leftId == 0){
reordered_tree.push_back(node);
}else{
reordered_tree.push_back(node);
reordered_tree.back().leftId = nodeId+1;
reordered_tree.back().rightIdOrIndex = nodeId+2;
nodeId += 2;
bfs_queue.push_back(node.leftId);
bfs_queue.push_back(node.rightIdOrIndex);
}
}
//overwrite old tree with pruned tree
m_tree = std::move(reordered_tree);
}
/// Find the leaf of the tree for a sample
template
std::size_t findLeaf(Vector const& pattern) const{
std::size_t nodeId = 0;
while(m_tree[nodeId].leftId != 0){
if(pattern[m_tree[nodeId].attributeIndex] <= m_tree[nodeId].attributeValue){
//Branch on left node
nodeId = m_tree[nodeId].leftId;
}else{
//Branch on right node
nodeId = m_tree[nodeId].rightIdOrIndex;
}
}
return nodeId;
}
private:
/// tree of the model
TreeType m_tree;
std::vector m_labels;
///Number of attributes (set by trainer)
std::size_t m_inputDimension;
Shape m_outputShape;
/// Evaluate the CART tree on a single sample
template
LabelType const& evalPattern(Vector const& pattern) const{
auto nodeId = findLeaf(pattern);
return m_labels[m_tree[nodeId].rightIdOrIndex];
}
};
}
#endif