/* * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES) * * This file is part of Orfeo Toolbox * * https://www.orfeo-toolbox.org/ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef otbTrainImagesBase_h #define otbTrainImagesBase_h #include "otbVectorDataFileWriter.h" #include "otbWrapperCompositeApplication.h" #include "otbWrapperApplicationFactory.h" #include "otbStatisticsXMLFileWriter.h" #include "otbImageToEnvelopeVectorDataFilter.h" #include "otbSamplingRateCalculator.h" #include "otbOGRDataToSamplePositionFilter.h" #include namespace otb { namespace Wrapper { /** \class TrainImagesBase * \brief Base class for the TrainImagesClassifier * * This class intends to hold common input/output parameters and * composite application connection for both supervised and unsupervised * model training. * * \ingroup OTBAppClassification */ class TrainImagesBase : public CompositeApplication { public: /** Standard class typedefs. */ typedef TrainImagesBase Self; typedef CompositeApplication Superclass; typedef itk::SmartPointer Pointer; typedef itk::SmartPointer ConstPointer; /** Standard macro */ itkTypeMacro(TrainImagesBase, Superclass); /** filters typedefs*/ typedef otb::OGRDataToSamplePositionFilter PeriodicSamplerType; typedef otb::SamplingRateCalculator::MapRateType MapRateType; protected: typedef enum { CLASS, GEOMETRIC } SamplingStrategy; struct SamplingRates; class TrainFileNamesHandler; /** * Initialize all the input and output parameter used for the train images */ void InitIO(); /** * Initialize sampling related application and parameters */ void InitSampling(); void ShareSamplingParameters(); void ConnectSamplingParameters(); void InitClassification(); void ShareClassificationParams(); void ConnectClassificationParams(); /** * Compute polygon statistics given provided strategy with PolygonClassStatistics class * \param imageList list of input images * \param vectorFileNames list of input vector file names * \param statisticsFileNames list of out */ void ComputePolygonStatistics(FloatVectorImageListType* imageList, const std::vector& vectorFileNames, const std::vector& statisticsFileNames); /** * Compute final maximum training and validation * \param dedicatedValidation * \return SamplingRates final maximum training and final maximum validation */ SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation); /** * Compute rates using MultiImageSamplingRate application * \param statisticsFileNames * \param ratesFileName * \param maximum final maximum value computed by ComputeFinalMaximumSamplingRates * \sa ComputeFinalMaximumSamplingRates */ void ComputeSamplingRate(const std::vector& statisticsFileNames, const std::string& ratesFileName, long maximum); /** * Train the model with training and optional validation data samples * \param imageList list of input images * \param sampleTrainFileNames files names of the training samples * \param sampleValidationFileNames file names of the validation sample */ void TrainModel(FloatVectorImageListType* imageList, const std::vector& sampleTrainFileNames, const std::vector& sampleValidationFileNames); /** * Select samples by class or by geographic strategy * \param image * \param vectorFileName * \param sampleFileName * \param statisticsFileName * \param ratesFileName * \param strategy */ void SelectAndExtractSamples(FloatVectorImageType* image, std::string vectorFileName, std::string sampleFileName, std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy, std::string selectedField = ""); /** * Select and extract samples with the SampleSelection and SampleExtraction application. * \param fileNames * \param imageList * \param vectorFileNames * \param strategy the strategy used for selection (by class or with geometry) * \param selectedFieldName */ void SelectAndExtractTrainSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList, std::vector vectorFileNames, SamplingStrategy strategy, std::string selectedFieldName = ""); /** * Function used to select validation samples based on a defined strategy (geometric in unsupervised mode) * and extract them. With dedicated validation the 'by class' sampling strategy and statistics are used. * Otherwise this function split training to validation samples corresponding to sample.vtr percentage. * or do nothing if this percentage is == 0 * \param fileNames * \param imageList * \param validationVectorFileList optional validation vector file for each images */ void SelectAndExtractValidationSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList, const std::vector& validationVectorFileList = std::vector()); /** * Function used to split all training samples from all images in a set of training and validation. * \param fileNames * \param imageList * \sa SplitTrainingAndValidationSamples */ void SplitTrainingToValidationSamples(const TrainFileNamesHandler& fileNames, FloatVectorImageListType* imageList); private: /** * Function used to split training samples in set of training and validation. * \param image input image * \param sampleFileName the input sample file name * \param sampleTrainFileName the input training file name * \param sampleValidFileName the input validation file name * \param ratesTrainFileName the rates file name */ void SplitTrainingAndValidationSamples(FloatVectorImageType* image, std::string sampleFileName, std::string sampleTrainFileName, std::string sampleValidFileName, std::string ratesTrainFileName); protected: struct SamplingRates { long int fmt; long int fmv; }; /** * \class TrainFileNamesHandler * This class is used to store file names requires for the application's input and output. * And to clear temporary files generated by the applications * \ingroup OTBAppClassification */ class TrainFileNamesHandler { public: void CreateTemporaryFileNames(std::string outModel, size_t nbInputs, bool dedicatedValidation) { if (dedicatedValidation) { rateTrainOut = outModel + "_ratesTrain.csv"; } else { rateTrainOut = outModel + "_rates.csv"; } rateValidOut = outModel + "_ratesValid.csv"; for (unsigned int i = 0; i < nbInputs; i++) { std::ostringstream oss; oss << i + 1; std::string strIndex(oss.str()); if (dedicatedValidation) { polyStatTrainOutputs.push_back(outModel + "_statsTrain_" + strIndex + ".xml"); polyStatValidOutputs.push_back(outModel + "_statsValid_" + strIndex + ".xml"); ratesTrainOutputs.push_back(outModel + "_ratesTrain_" + strIndex + ".csv"); ratesValidOutputs.push_back(outModel + "_ratesValid_" + strIndex + ".csv"); sampleOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp"); } else { polyStatTrainOutputs.push_back(outModel + "_stats_" + strIndex + ".xml"); ratesTrainOutputs.push_back(outModel + "_rates_" + strIndex + ".csv"); sampleOutputs.push_back(outModel + "_samples_" + strIndex + ".shp"); } sampleTrainOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp"); sampleValidOutputs.push_back(outModel + "_samplesValid_" + strIndex + ".shp"); } } void clear() { for (unsigned int i = 0; i < polyStatTrainOutputs.size(); i++) RemoveFile(polyStatTrainOutputs[i]); for (unsigned int i = 0; i < polyStatValidOutputs.size(); i++) RemoveFile(polyStatValidOutputs[i]); for (unsigned int i = 0; i < ratesTrainOutputs.size(); i++) RemoveFile(ratesTrainOutputs[i]); for (unsigned int i = 0; i < ratesValidOutputs.size(); i++) RemoveFile(ratesValidOutputs[i]); for (unsigned int i = 0; i < sampleOutputs.size(); i++) RemoveFile(sampleOutputs[i]); for (unsigned int i = 0; i < sampleTrainOutputs.size(); i++) RemoveFile(sampleTrainOutputs[i]); for (unsigned int i = 0; i < sampleValidOutputs.size(); i++) RemoveFile(sampleValidOutputs[i]); for (unsigned int i = 0; i < tmpVectorFileList.size(); i++) RemoveFile(tmpVectorFileList[i]); } public: std::vector polyStatTrainOutputs; std::vector polyStatValidOutputs; std::vector ratesTrainOutputs; std::vector ratesValidOutputs; std::vector sampleOutputs; std::vector sampleTrainOutputs; std::vector sampleValidOutputs; std::vector tmpVectorFileList; std::string rateValidOut; std::string rateTrainOut; private: bool RemoveFile(std::string& filePath) { bool res = true; if (itksys::SystemTools::FileExists(filePath)) { size_t posExt = filePath.rfind('.'); if (posExt != std::string::npos && filePath.compare(posExt, std::string::npos, ".shp") == 0) { std::string shxPath = filePath.substr(0, posExt) + std::string(".shx"); std::string dbfPath = filePath.substr(0, posExt) + std::string(".dbf"); std::string prjPath = filePath.substr(0, posExt) + std::string(".prj"); RemoveFile(shxPath); RemoveFile(dbfPath); RemoveFile(prjPath); } res = itksys::SystemTools::RemoveFile(filePath); if (!res) { // otbAppLogINFO( <<"Unable to remove file "<