/* * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES) * * This file is part of Orfeo Toolbox * * https://www.orfeo-toolbox.org/ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef otbSharkKMeansMachineLearningModel_hxx #define otbSharkKMeansMachineLearningModel_hxx #include #include "boost/make_shared.hpp" #include "itkMacro.h" #include "otbSharkKMeansMachineLearningModel.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wshadow" #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Woverloaded-virtual" #pragma GCC diagnostic ignored "-Wignored-qualifiers" #endif #include "otb_shark.h" #include "otbSharkUtils.h" #include "shark/Algorithms/KMeans.h" //k-means algorithm #include "shark/Models/Clustering/HardClusteringModel.h" #include "shark/Models/Clustering/SoftClusteringModel.h" #include //load the csv file #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif namespace otb { template SharkKMeansMachineLearningModel::SharkKMeansMachineLearningModel() : m_K(2), m_MaximumNumberOfIterations(10) { // Default set HardClusteringModel this->m_ConfidenceIndex = true; m_ClusteringModel = boost::make_shared(&m_Centroids); } template SharkKMeansMachineLearningModel::~SharkKMeansMachineLearningModel() { } /** Train the machine learning model */ template void SharkKMeansMachineLearningModel::Train() { // Parse input data and convert to Shark Data std::vector vector_data; otb::Shark::ListSampleToSharkVector(this->GetInputListSample(), vector_data); shark::Data data = shark::createDataFromRange(vector_data); // Use a Hard Clustering Model for classification shark::kMeans(data, m_K, m_Centroids, m_MaximumNumberOfIterations); m_ClusteringModel = boost::make_shared(&m_Centroids); } template typename SharkKMeansMachineLearningModel::TargetSampleType SharkKMeansMachineLearningModel::DoPredict(const InputSampleType& value, ConfidenceValueType* quality, ProbaSampleType* proba) const { shark::RealVector data(value.Size()); for (size_t i = 0; i < value.Size(); i++) { data.push_back(value[i]); } // Change quality measurement only if SoftClustering or other clustering method is used. if (quality != nullptr) { // unsigned int probas = (*m_ClusteringModel)( data ); (*quality) = ConfidenceValueType(1.); } if (proba != nullptr) { if (!this->m_ProbaIndex) { itkExceptionMacro("Probability per class not available for this classifier !"); } } TargetSampleType target; ClusteringOutputType predictedValue = (*m_ClusteringModel)(data); target[0] = static_cast(predictedValue); return target; } template void SharkKMeansMachineLearningModel::DoPredictBatch(const InputListSampleType* input, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType* targets, ConfidenceListSampleType* quality, ProbaListSampleType* proba) const { // Perform check on input values assert(input != nullptr); assert(targets != nullptr); // input list sample and target list sample should be initialized and without assert(input->Size() == targets->Size() && "Input sample list and target label list do not have the same size."); assert(((quality == nullptr) || (quality->Size() == input->Size())) && "Quality samples list is not null and does not have the same size as input samples list"); if (startIndex + size > input->Size()) { itkExceptionMacro(<< "requested range [" << startIndex << ", " << startIndex + size << "[ partially outside input sample list range.[0," << input->Size() << "["); } // Convert input list of features to shark data format std::vector features; otb::Shark::ListSampleRangeToSharkVector(input, features, startIndex, size); shark::Data inputSamples = shark::createDataFromRange(features); shark::Data clusters; try { clusters = (*m_ClusteringModel)(inputSamples); } catch (...) { itkExceptionMacro( "Failed to run clustering classification. " "The number of features of input samples and the model could differ."); } unsigned int id = startIndex; for (const auto& p : clusters.elements()) { TargetSampleType target; target[0] = static_cast(p); targets->SetMeasurementVector(id, target); ++id; } // Change quality measurement only if SoftClustering or other clustering method is used. if (quality != nullptr) { for (unsigned int qid = startIndex; qid < startIndex + size; ++qid) { quality->SetMeasurementVector(qid, static_cast(1.)); } } if (proba != nullptr && !this->m_ProbaIndex) { itkExceptionMacro("Probability per class not available for this classifier !"); } } template void SharkKMeansMachineLearningModel::Save(const std::string& filename, const std::string& itkNotUsed(name)) { std::ofstream ofs(filename); if (!ofs) { itkExceptionMacro(<< "Error opening " << filename.c_str()); } ofs << "#" << m_ClusteringModel->name() << std::endl; shark::TextOutArchive oa(ofs); m_ClusteringModel->save(oa, 1); } template void SharkKMeansMachineLearningModel::Load(const std::string& filename, const std::string& itkNotUsed(name)) { m_CanRead = false; std::ifstream ifs(filename); if (ifs.good()) { // Check if first line contains model name std::string line; std::getline(ifs, line); m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos; } if (!m_CanRead) return; shark::TextInArchive ia(ifs); m_ClusteringModel->load(ia, 0); ifs.close(); } template bool SharkKMeansMachineLearningModel::CanReadFile(const std::string& file) { try { m_CanRead = true; this->Load(file); } catch (...) { return false; } return m_CanRead; } template bool SharkKMeansMachineLearningModel::CanWriteFile(const std::string& itkNotUsed(file)) { return true; } template void SharkKMeansMachineLearningModel::ExportCentroids(const std::string& filename) { shark::exportCSV(m_Centroids.centroids(), filename, ' '); } template void SharkKMeansMachineLearningModel::PrintSelf(std::ostream& os, itk::Indent indent) const { // Call superclass implementation Superclass::PrintSelf(os, indent); } } // end namespace otb #endif