/*========================================================================= * * Copyright Insight Software Consortium * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef itkBackPropagationLayer_hxx #define itkBackPropagationLayer_hxx #include "itkBackPropagationLayer.h" namespace itk { namespace Statistics { template BackPropagationLayer ::BackPropagationLayer() { m_Bias = 1; } template BackPropagationLayer ::~BackPropagationLayer() {} template void BackPropagationLayer ::SetNumberOfNodes(unsigned int c) { LayerBase::SetNumberOfNodes(c); this->m_NodeInputValues.set_size(c); this->m_NodeOutputValues.set_size(c); m_InputErrorValues.set_size(c); m_OutputErrorValues.set_size(c); this->Modified(); } template void BackPropagationLayer ::SetInputValue(unsigned int i, ValueType value) { this->m_NodeInputValues[i] = value; this->Modified(); } template typename BackPropagationLayer::ValueType BackPropagationLayer ::GetInputValue(unsigned int i) const { return m_NodeInputValues[i]; } template void BackPropagationLayer ::SetOutputValue(unsigned int i, ValueType value) { m_NodeOutputValues(i) = value; this->Modified(); } template typename BackPropagationLayer::ValueType BackPropagationLayer ::GetOutputValue(unsigned int i) const { return m_NodeOutputValues(i); } template void BackPropagationLayer ::SetOutputVector(TMeasurementVector value) { m_NodeOutputValues = value.GetVnlVector(); this->Modified(); } 
// NOTE(review): in this section every template parameter list and template
// argument list (text between '<' and '>') appears to have been stripped,
// e.g. "template BackPropagationLayer ::..." and "vnl_vector temp_vnl".
// As written these definitions do not compile. The originals were presumably
// "template <typename TMeasurementVector, typename TTargetVector>",
// "BackPropagationLayer<TMeasurementVector, TTargetVector>",
// "vnl_vector<ValueType>", etc. Restore them from version control rather than
// by hand, because some loop bodies were lost too (marked below).

/** Return a pointer to the first element of this layer's node output
 *  values (m_NodeOutputValues.data_block()); it points into internal storage. */
template typename BackPropagationLayer::ValueType *
BackPropagationLayer
::GetOutputVector()
{
  return m_NodeOutputValues.data_block();
}

/** Return the input error value stored for node n (operator[], no explicit
 *  range check here). */
template typename BackPropagationLayer::ValueType
BackPropagationLayer
::GetInputErrorValue(unsigned int n) const
{
  return m_InputErrorValues[n];
}

/** Return a pointer to the first element of the internal input error
 *  values (m_InputErrorValues.data_block()). */
template typename BackPropagationLayer::ValueType *
BackPropagationLayer
::GetInputErrorVector()
{
  return m_InputErrorValues.data_block();
}

/** Store v as the input error of node i and mark the object modified.
 *  Note the argument order is (value, index), the reverse of
 *  SetInputValue(index, value). */
template void
BackPropagationLayer
::SetInputErrorValue(ValueType v, unsigned int i)
{
  m_InputErrorValues[i] = v;
  this->Modified();
}

/** Forward pass for a layer fed by a previous layer.
 *
 *  The previous layer's outputs are read from the input weight set
 *  (GetInputValues()), and m_Bias is appended as the last element. That vector
 *  then scales the columns of the weight matrix. For each node j, the node
 *  input function is evaluated on row j of the product, giving the node input
 *  value. The activation (transfer) function applied to that value gives the
 *  node output value.
 */
template void
BackPropagationLayer
::ForwardPropagate()
{
  typename Superclass::InputFunctionInterfaceType::Pointer inputfunction = this->GetModifiableNodeInputFunction();
  typename Superclass::TransferFunctionInterfaceType::Pointer transferfunction = this->GetModifiableActivationFunction();
  typename Superclass::WeightSetType::Pointer inputweightset = this->GetModifiableInputWeightSet();
  //API change WeightSets are just containers
  // Weight matrix shape: one row per output node of the weight set, one
  // column per input node (the last column pairs with the bias entry).
  const int wcols = inputweightset->GetNumberOfInputNodes();
  const int wrows = inputweightset->GetNumberOfOutputNodes();
  ValueType * inputvalues = inputweightset->GetInputValues();

  // Build [inputvalues[0 .. wcols-2], m_Bias].
  // Assumes GetInputValues() provides at least wcols-1 values — TODO confirm.
  vnl_vector prevlayeroutputvector;
  vnl_vector tprevlayeroutputvector;
  prevlayeroutputvector.set_size(wcols);
  tprevlayeroutputvector.set_size(wcols-1);
  tprevlayeroutputvector.copy_in(inputvalues);
  prevlayeroutputvector.update(tprevlayeroutputvector,0);
  prevlayeroutputvector[wcols-1]=m_Bias;
  vnl_diag_matrix PrevLayerOutput(prevlayeroutputvector);

  // wrows x wcols matrix built from the flat weight buffer.
  ValueType * weightvalues = inputweightset->GetWeightValues();
  vnl_matrix weightmatrix(weightvalues,wrows, wcols);

  // NOTE(review): rows/cols are expected to equal wrows/wcols. m_InputWeightSet
  // is presumably the same object as inputweightset above — confirm.
  const unsigned int rows = this->m_NumberOfNodes;
  const unsigned int cols = this->m_InputWeightSet->GetNumberOfInputNodes();
  vnl_matrix inputmatrix;
  inputmatrix.set_size(rows, cols);
  // Matrix times diagonal matrix: element (j, k) is
  // weight(j, k) * prevlayeroutputvector[k].
  inputmatrix=weightmatrix*PrevLayerOutput;
  inputfunction->SetSize(cols); //include bias
  for (unsigned int j = 0; j < rows; j++)
    {
    vnl_vector temp_vnl(inputmatrix.get_row(j));
    m_NodeInputValues.put(j, inputfunction->Evaluate(temp_vnl.data_block()));
    m_NodeOutputValues.put(j, transferfunction->Evaluate(m_NodeInputValues[j]));
    }
}

//overloaded to handle input layers
/** Forward pass for a layer fed directly by a sample.
 *
 *  Applies the activation function to each element of samplevector. It writes
 *  the result back into the by-value parameter samplevector and stores it as
 *  node output i. Node input values are not updated here.
 *  NOTE(review): assumes samplevector.Size() does not exceed the node count.
 */
template void
BackPropagationLayer
::ForwardPropagate(TMeasurementVector samplevector)
{
  typename Superclass::TransferFunctionInterfaceType::ConstPointer transferfunction = this->GetActivationFunction();
  for (unsigned int i = 0; i < samplevector.Size(); i++)
    {
    samplevector[i] = transferfunction->Evaluate(samplevector[i]);
    m_NodeOutputValues.put(i, samplevector[i]);
    }
}

/** Backward pass driven by an explicit error vector.
 *
 *  The input error of node i is errors[i] * DActivation(GetInputValue(i)).
 *  The weight deltas written to the input weight set are the outer product of
 *  the input-error column (num_nodes x 1) with a 1 x N row. That row holds the
 *  previous layer's values from GetInputValues() followed by 0 in the last
 *  position; the last (bias) column of the deltas is then zeroed explicitly.
 *  The input error vector itself is passed as the bias deltas.
 */
template void
BackPropagationLayer
::BackwardPropagate(InternalVectorType errors)
{
  const int num_nodes = this->GetNumberOfNodes();
  typename Superclass::WeightSetType::Pointer inputweightset = Superclass::GetModifiableInputWeightSet();
  // NOTE(review): writes m_InputErrorValues[i] for every i < errors.Size();
  // assumes errors.Size() <= num_nodes — confirm at call sites.
  for (unsigned int i = 0; i < errors.Size(); i++)
    {
    SetInputErrorValue(errors[i] * DActivation(GetInputValue(i)), i);
    }
  // num_nodes x 1 column of input errors.
  vnl_matrix inputerrormatrix(GetInputErrorVector(), num_nodes, 1);
  // NOTE(review): inputerrorvector is never used in this function.
  vnl_vector inputerrorvector(GetInputErrorVector(), num_nodes);
  vnl_matrix DW_temp(inputweightset->GetNumberOfOutputNodes(), inputweightset->GetNumberOfInputNodes());
  vnl_matrix InputLayerOutput(1, inputweightset->GetNumberOfInputNodes());
  vnl_matrix tempInputLayerOutput(1, inputweightset->GetNumberOfInputNodes()-1);
  tempInputLayerOutput.copy_in(inputweightset->GetInputValues());
  InputLayerOutput.fill(0.0);
  // NOTE(review): the loop condition below is garbled by the stripped '<...>'
  // text; it was presumably "i<inputweightset->GetNumberOfInputNodes()-1".
  for(unsigned int i=0; iGetNumberOfInputNodes()-1; i++)
    InputLayerOutput.put(0,i, tempInputLayerOutput.get(0,i));
  //InputLayerOutput.copy_in(inputweightset->GetInputValues());
  DW_temp = inputerrormatrix * InputLayerOutput;
  // Zero the bias column; bias deltas are supplied via SetDeltaBValues below.
  DW_temp.set_column(inputweightset->GetNumberOfInputNodes()-1,0.0);
  inputweightset->SetDeltaValues(DW_temp.data_block());
  inputweightset->SetDeltaBValues(GetInputErrorVector());
}

/** Store the output error values and mark the object modified.
 *  NOTE(review): the loop condition and body were lost (stripped '<...>'
 *  text). Presumably the loop copies errors[i] into m_OutputErrorValues[i]
 *  for i < errors.Size() — restore from version control; this does not compile.
 */
template void
BackPropagationLayer
::SetOutputErrorValues(TTargetVector errors)
{
  for(unsigned int i=0; iModified();
}

/** Return the output error value stored for node i. */
template typename BackPropagationLayer::ValueType
BackPropagationLayer
::GetOutputErrorValue(unsigned int i) const
{
  return m_OutputErrorValues[i];
}

/** Backward pass driven by the downstream layer, reached through the output
 *  weight set.
 *
 *  Reads the output weight set's input values, delta values and weights as a
 *  num_nodes-length vector and rows x num_nodes matrices, where rows is the
 *  output weight set's output node count. It is presumably meant to derive
 *  this layer's input errors from them. Then, exactly as in
 *  BackwardPropagate(errors), it forms the input weight set's weight deltas
 *  (bias column zeroed) and passes the input error vector as the bias deltas.
 *  NOTE(review): the code between the "c1" loop header and the declaration of
 *  inputerrormatrix was lost (stripped '<...>' text). That includes whatever
 *  fills deltaww and the input errors. Restore it from version control; as
 *  written this does not compile.
 */
template void
BackPropagationLayer
::BackwardPropagate()
{
  unsigned int num_nodes = this->GetNumberOfNodes();
  typename Superclass::WeightSetType::Pointer outputweightset = Superclass::GetModifiableOutputWeightSet();
  typename Superclass::WeightSetType::Pointer inputweightset = Superclass::GetModifiableInputWeightSet();
  vnl_vector OutputLayerInput(outputweightset->GetInputValues(),num_nodes);
  const ValueType * deltavalues = outputweightset->GetDeltaValues();
  const ValueType * weightvalues = outputweightset->GetWeightValues();
  const unsigned int cols = num_nodes;
  const unsigned int rows = outputweightset->GetNumberOfOutputNodes();
  vnl_matrix weightmatrix(weightvalues, rows, cols);
  vnl_matrix deltamatrix(deltavalues, rows, cols);
  vnl_vector deltaww;
  deltaww.set_size(cols);
  deltaww.fill(0);
  for(unsigned int c1=0; c1 inputerrormatrix(GetInputErrorVector(), num_nodes, 1);
  vnl_matrix DW_temp(inputweightset->GetNumberOfOutputNodes(), inputweightset->GetNumberOfInputNodes());
  vnl_matrix InputLayerOutput(1, inputweightset->GetNumberOfInputNodes());
  //InputLayerOutput.copy_in(inputweightset->GetInputValues());
  vnl_matrix tempInputLayerOutput(1, inputweightset->GetNumberOfInputNodes()-1);
  tempInputLayerOutput.copy_in(inputweightset->GetInputValues());
  InputLayerOutput.fill(0.0);
  // NOTE(review): garbled loop condition; presumably
  // "i<inputweightset->GetNumberOfInputNodes()-1".
  for(unsigned int i=0; iGetNumberOfInputNodes()-1; i++)
    InputLayerOutput.put(0,i, tempInputLayerOutput.get(0,i));
  DW_temp = inputerrormatrix * InputLayerOutput;
  // Zero the bias column; bias deltas are supplied via SetDeltaBValues below.
  DW_temp.set_column(inputweightset->GetNumberOfInputNodes()-1,0.0);
  inputweightset->SetDeltaValues(DW_temp.data_block());
  inputweightset->SetDeltaBValues(GetInputErrorVector());
}

/** Evaluate the layer's activation function at n. */
template typename BackPropagationLayer::ValueType
BackPropagationLayer
::Activation(ValueType n)
{
  return this->m_ActivationFunction->Evaluate(n);
}

/** Evaluate the derivative of the layer's activation function at n. */
template typename BackPropagationLayer::ValueType
BackPropagationLayer
::DActivation(ValueType n)
{
  return this->m_ActivationFunction->EvaluateDerivative(n);
}

/** Print the object */
// Prints this layer's node values, error values and bias, then delegates to
// Superclass::PrintSelf.
template void
BackPropagationLayer
::PrintSelf( std::ostream& os, Indent indent ) const
{
  os << indent << "BackPropagationLayer(" << this << ")" << std::endl;
  os << indent << "m_NodeInputValues = " << m_NodeInputValues << std::endl;
  os << indent << "m_NodeOutputValues = " << m_NodeOutputValues << std::endl;
  os << indent << "m_InputErrorValues = " << m_InputErrorValues << std::endl;
  os << indent << "m_OutputErrorValues = " << m_OutputErrorValues << std::endl;
  os << indent << "Bias = " << m_Bias << std::endl;
  Superclass::PrintSelf( os, indent );
}

} // end namespace Statistics
} // end namespace itk

#endif