/*========================================================================= * * Copyright Insight Software Consortium * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkKdTreeGenerator_hxx #define itkKdTreeGenerator_hxx #include "itkKdTreeGenerator.h" namespace itk { namespace Statistics { template< typename TSample > KdTreeGenerator< TSample > ::KdTreeGenerator() { m_SourceSample = ITK_NULLPTR; m_BucketSize = 16; m_Subsample = SubsampleType::New(); m_MeasurementVectorSize = 0; } template< typename TSample > void KdTreeGenerator< TSample > ::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); os << indent << "Source Sample: "; if ( m_SourceSample != ITK_NULLPTR ) { os << m_SourceSample << std::endl; } else { os << "not set." << std::endl; } os << indent << "Bucket Size: " << m_BucketSize << std::endl; os << indent << "MeasurementVectorSize: " << m_MeasurementVectorSize << std::endl; } template< typename TSample > void KdTreeGenerator< TSample > ::SetSample(TSample *sample) { m_SourceSample = sample; m_Subsample->SetSample(sample); m_Subsample->InitializeWithAllInstances(); m_MeasurementVectorSize = sample->GetMeasurementVectorSize(); NumericTraits::SetLength(m_TempLowerBound, m_MeasurementVectorSize); NumericTraits::SetLength(m_TempUpperBound, m_MeasurementVectorSize); NumericTraits::SetLength(m_TempMean, m_MeasurementVectorSize); } template< typename TSample > void KdTreeGenerator< TSample > ::SetBucketSize(unsigned int size) { m_BucketSize = size; } template< typename TSample > void KdTreeGenerator< TSample > ::GenerateData() { if ( m_SourceSample == ITK_NULLPTR ) { return; } if ( m_Tree.IsNull() ) { m_Tree = KdTreeType::New(); m_Tree->SetSample(m_SourceSample); m_Tree->SetBucketSize(m_BucketSize); } SubsamplePointer subsample = this->GetSubsample(); // Sanity check. Verify that the subsample has measurement vectors of the // same length as the sample generated by the tree. if ( this->GetMeasurementVectorSize() != subsample->GetMeasurementVectorSize() ) { itkExceptionMacro(<< "Measurement Vector Length mismatch"); } MeasurementVectorType lowerBound; NumericTraits::SetLength(lowerBound, m_MeasurementVectorSize); MeasurementVectorType upperBound; NumericTraits::SetLength(upperBound, m_MeasurementVectorSize); for ( unsigned int d = 0; d < m_MeasurementVectorSize; d++ ) { lowerBound[d] = NumericTraits< MeasurementType >::NonpositiveMin(); upperBound[d] = NumericTraits< MeasurementType >::max(); } KdTreeNodeType *root = this->GenerateTreeLoop(0, m_Subsample->Size(), lowerBound, upperBound, 0); m_Tree->SetRoot(root); } template< typename TSample > inline typename KdTreeGenerator< TSample >::KdTreeNodeType * KdTreeGenerator< TSample > ::GenerateNonterminalNode(unsigned int beginIndex, unsigned int endIndex, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound, unsigned int level) { typedef typename KdTreeType::KdTreeNodeType NodeType; MeasurementType dimensionLowerBound; MeasurementType dimensionUpperBound; MeasurementType partitionValue; unsigned int partitionDimension = 0; unsigned int i; MeasurementType spread; MeasurementType maxSpread; unsigned int medianIndex; SubsamplePointer subsample = this->GetSubsample(); // find most widely spread dimension Algorithm::FindSampleBoundAndMean< SubsampleType >(subsample, beginIndex, endIndex, m_TempLowerBound, m_TempUpperBound, m_TempMean); maxSpread = NumericTraits< MeasurementType >::NonpositiveMin(); for ( i = 0; i < m_MeasurementVectorSize; i++ ) { spread = m_TempUpperBound[i] - m_TempLowerBound[i]; if ( spread >= maxSpread ) { maxSpread = spread; partitionDimension = i; } } medianIndex = ( endIndex - beginIndex ) / 2; // // Find the medial element by using the NthElement function // based on the STL implementation of the QuickSelect algorithm. // partitionValue = Algorithm::NthElement< SubsampleType >(m_Subsample, partitionDimension, beginIndex, endIndex, medianIndex); medianIndex += beginIndex; // save bounds for cutting dimension dimensionLowerBound = lowerBound[partitionDimension]; dimensionUpperBound = upperBound[partitionDimension]; upperBound[partitionDimension] = partitionValue; const unsigned int beginLeftIndex = beginIndex; const unsigned int endLeftIndex = medianIndex; NodeType * left = GenerateTreeLoop(beginLeftIndex, endLeftIndex, lowerBound, upperBound, level + 1); upperBound[partitionDimension] = dimensionUpperBound; lowerBound[partitionDimension] = partitionValue; const unsigned int beginRightIndex = medianIndex + 1; const unsigned int endRightIndex = endIndex; NodeType * right = GenerateTreeLoop(beginRightIndex, endRightIndex, lowerBound, upperBound, level + 1); lowerBound[partitionDimension] = dimensionLowerBound; typedef KdTreeNonterminalNode< TSample > KdTreeNonterminalNodeType; KdTreeNonterminalNodeType *nonTerminalNode = new KdTreeNonterminalNodeType(partitionDimension, partitionValue, left, right); nonTerminalNode->AddInstanceIdentifier( subsample->GetInstanceIdentifier(medianIndex) ); return nonTerminalNode; } template< typename TSample > inline typename KdTreeGenerator< TSample >::KdTreeNodeType * KdTreeGenerator< TSample > ::GenerateTreeLoop(unsigned int beginIndex, unsigned int endIndex, MeasurementVectorType & lowerBound, MeasurementVectorType & upperBound, unsigned int level) { if ( endIndex - beginIndex <= m_BucketSize ) { // numberOfInstances small, make a terminal node if ( endIndex == beginIndex ) { // return the pointer to empty terminal node return m_Tree->GetEmptyTerminalNode(); } else { KdTreeTerminalNode< TSample > *ptr = new KdTreeTerminalNode< TSample >(); for ( unsigned int j = beginIndex; j < endIndex; j++ ) { ptr->AddInstanceIdentifier( this->GetSubsample()->GetInstanceIdentifier(j) ); } // return a terminal node return ptr; } } else { return this->GenerateNonterminalNode(beginIndex, endIndex, lowerBound, upperBound, level + 1); } } } // end of namespace Statistics } // end of namespace itk #endif