/* * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES) * * This file is part of Orfeo Toolbox * * https://www.orfeo-toolbox.org/ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef otbMachineLearningModelFactory_hxx #define otbMachineLearningModelFactory_hxx #include "otbMachineLearningModelFactory.h" #include "otbConfigure.h" #ifdef OTB_USE_OPENCV #include "otb_opencv_api.h" #include "otbKNearestNeighborsMachineLearningModelFactory.h" #include "otbRandomForestsMachineLearningModelFactory.h" #include "otbSVMMachineLearningModelFactory.h" #include "otbBoostMachineLearningModelFactory.h" #include "otbNeuralNetworkMachineLearningModelFactory.h" #include "otbNormalBayesMachineLearningModelFactory.h" #include "otbDecisionTreeMachineLearningModelFactory.h" #endif #ifdef OTB_USE_LIBSVM #include "otbLibSVMMachineLearningModelFactory.h" #endif #ifdef OTB_USE_SHARK #include "otbSharkRandomForestsMachineLearningModelFactory.h" #include "otbSharkKMeansMachineLearningModelFactory.h" #endif #include "itkMutexLockHolder.h" namespace otb { template typename MachineLearningModel::Pointer MachineLearningModelFactory::CreateMachineLearningModel(const std::string& path, FileModeType mode) { RegisterBuiltInFactories(); std::list possibleMachineLearningModel; std::list allobjects = itk::ObjectFactoryBase::CreateAllInstance("otbMachineLearningModel"); for (std::list::iterator i = allobjects.begin(); i != allobjects.end(); ++i) { MachineLearningModel* io = dynamic_cast*>(i->GetPointer()); if (io) { possibleMachineLearningModel.push_back(io); } else { std::cerr << "Error MachineLearningModel Factory did not return an MachineLearningModel: " << (*i)->GetNameOfClass() << std::endl; } } for (typename std::list::iterator k = possibleMachineLearningModel.begin(); k != possibleMachineLearningModel.end(); ++k) { if (mode == ReadMode) { if ((*k)->CanReadFile(path)) { return *k; } } else if (mode == WriteMode) { if ((*k)->CanWriteFile(path)) { return *k; } } } return nullptr; } template void MachineLearningModelFactory::RegisterBuiltInFactories() { itk::MutexLockHolder lockHolder(mutex); #ifdef OTB_USE_LIBSVM RegisterFactory(LibSVMMachineLearningModelFactory::New()); #endif #ifdef OTB_USE_SHARK RegisterFactory(SharkRandomForestsMachineLearningModelFactory::New()); RegisterFactory(SharkKMeansMachineLearningModelFactory::New()); #endif #ifdef OTB_USE_OPENCV RegisterFactory(RandomForestsMachineLearningModelFactory::New()); RegisterFactory(SVMMachineLearningModelFactory::New()); RegisterFactory(BoostMachineLearningModelFactory::New()); RegisterFactory(NeuralNetworkMachineLearningModelFactory::New()); RegisterFactory(NormalBayesMachineLearningModelFactory::New()); RegisterFactory(DecisionTreeMachineLearningModelFactory::New()); RegisterFactory(KNearestNeighborsMachineLearningModelFactory::New()); #endif } template void MachineLearningModelFactory::RegisterFactory(itk::ObjectFactoryBase* factory) { // Unregister any previously registered factory of the same class // Might be more intensive but static bool is not an option due to // ld error. itk::ObjectFactoryBase::UnRegisterFactory(factory); itk::ObjectFactoryBase::RegisterFactory(factory); } template void MachineLearningModelFactory::CleanFactories() { itk::MutexLockHolder lockHolder(mutex); std::list factories = itk::ObjectFactoryBase::GetRegisteredFactories(); std::list::iterator itFac; for (itFac = factories.begin(); itFac != factories.end(); ++itFac) { #ifdef OTB_USE_LIBSVM LibSVMMachineLearningModelFactory* libsvmFactory = dynamic_cast*>(*itFac); if (libsvmFactory) { itk::ObjectFactoryBase::UnRegisterFactory(libsvmFactory); continue; } #endif #ifdef OTB_USE_SHARK SharkRandomForestsMachineLearningModelFactory* sharkRFFactory = dynamic_cast*>(*itFac); if (sharkRFFactory) { itk::ObjectFactoryBase::UnRegisterFactory(sharkRFFactory); continue; } SharkKMeansMachineLearningModelFactory* sharkKMeansFactory = dynamic_cast*>(*itFac); if (sharkKMeansFactory) { itk::ObjectFactoryBase::UnRegisterFactory(sharkKMeansFactory); continue; } #endif #ifdef OTB_USE_OPENCV // RandomForest RandomForestsMachineLearningModelFactory* rfFactory = dynamic_cast*>(*itFac); if (rfFactory) { itk::ObjectFactoryBase::UnRegisterFactory(rfFactory); continue; } // SVM SVMMachineLearningModelFactory* svmFactory = dynamic_cast*>(*itFac); if (svmFactory) { itk::ObjectFactoryBase::UnRegisterFactory(svmFactory); continue; } // Boost BoostMachineLearningModelFactory* boostFactory = dynamic_cast*>(*itFac); if (boostFactory) { itk::ObjectFactoryBase::UnRegisterFactory(boostFactory); continue; } // ANN NeuralNetworkMachineLearningModelFactory* annFactory = dynamic_cast*>(*itFac); if (annFactory) { itk::ObjectFactoryBase::UnRegisterFactory(annFactory); continue; } // Bayes NormalBayesMachineLearningModelFactory* bayesFactory = dynamic_cast*>(*itFac); if (bayesFactory) { itk::ObjectFactoryBase::UnRegisterFactory(bayesFactory); continue; } // Decision Tree DecisionTreeMachineLearningModelFactory* dtFactory = dynamic_cast*>(*itFac); if (dtFactory) { itk::ObjectFactoryBase::UnRegisterFactory(dtFactory); continue; } // KNN KNearestNeighborsMachineLearningModelFactory* knnFactory = dynamic_cast*>(*itFac); if (knnFactory) { itk::ObjectFactoryBase::UnRegisterFactory(knnFactory); continue; } #endif } } } // end namespace otb #endif