///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <core/viewport/Viewport.h>
#include <core/viewport/ViewportManager.h>
#include <core/scene/animation/controller/StandardControllers.h>
#include <core/scene/animation/AnimManager.h>
#include <core/gui/properties/FloatControllerUI.h>
#include <core/gui/properties/BooleanPropertyUI.h>
#include <core/gui/properties/SubObjectParameterUI.h>
#include <core/utilities/ProgressIndicator.h>

#include <atomviz/atoms/datachannels/OrientationDataChannel.h>
#include <atomviz/utils/NearestNeighborList.h>
#include <atomviz/utils/OnTheFlyNeighborList.h>

#include <gsl/gsl_multifit.h>

#include "CalculateExtrinsicStrainModifier.h"

namespace CrystalAnalysis {

// Registers CalculateExtrinsicStrainModifier (base class AtomsObjectAnalyzerBase) as a
// serializable plugin class and defines its four reference fields, which hold the
// analysis result channels. The quoted strings are the fields' identifiers
// (presumably the keys used when a scene is saved/loaded — TODO confirm against the
// macro definitions).
IMPLEMENT_SERIALIZABLE_PLUGIN_CLASS(CalculateExtrinsicStrainModifier, AtomsObjectAnalyzerBase)
DEFINE_REFERENCE_FIELD(CalculateExtrinsicStrainModifier, DeformationGradientDataChannel, "DeformationGradientChannel", deformationGradientChannel)
DEFINE_REFERENCE_FIELD(CalculateExtrinsicStrainModifier, DataChannel, "StrainTensorChannel", strainTensorChannel)
DEFINE_REFERENCE_FIELD(CalculateExtrinsicStrainModifier, DataChannel, "HydrostaticStrainChannel", hydrostaticStrainChannel)
DEFINE_REFERENCE_FIELD(CalculateExtrinsicStrainModifier, DataChannel, "ShearStrainChannel", shearStrainChannel)

/******************************************************************************
* Constructs the modifier object.
*
* Registers the four result-channel reference fields. For a newly created
* modifier (isLoading == false) it also creates the result channels: two
* standard tensor channels and two custom FloatType channels (component
* count 1) named "Hydrostatic Strain" and "Shear Strain". When isLoading is
* true the fields are left unset (presumably to be restored by
* deserialization).
******************************************************************************/
CalculateExtrinsicStrainModifier::CalculateExtrinsicStrainModifier(bool isLoading)
	: AtomsObjectAnalyzerBase(isLoading)
{
	INIT_PROPERTY_FIELD(CalculateExtrinsicStrainModifier, deformationGradientChannel);
	INIT_PROPERTY_FIELD(CalculateExtrinsicStrainModifier, strainTensorChannel);
	INIT_PROPERTY_FIELD(CalculateExtrinsicStrainModifier, hydrostaticStrainChannel);
	INIT_PROPERTY_FIELD(CalculateExtrinsicStrainModifier, shearStrainChannel);

	// Nothing else to set up while the object is being loaded.
	if(isLoading)
		return;

	// Tensor-valued standard channels.
	deformationGradientChannel = new DeformationGradientDataChannel(DataChannel::DeformationGradientChannel);
	strainTensorChannel = new DataChannel(DataChannel::StrainTensorChannel);

	// Scalar custom channels holding one FloatType value per atom.
	const int floatTypeId = qMetaTypeId<FloatType>();
	hydrostaticStrainChannel = new DataChannel(floatTypeId, sizeof(FloatType), 1);
	hydrostaticStrainChannel->setName(tr("Hydrostatic Strain"));
	shearStrainChannel = new DataChannel(floatTypeId, sizeof(FloatType), 1);
	shearStrainChannel->setName(tr("Shear Strain"));
}

/******************************************************************************
* Applies the previously calculated analysis results to the atoms object.
*
* Copies of the four internal result channels are attached to the output
* object. The deformation-gradient and strain-tensor copies replace the
* corresponding standard output channels. The hydrostatic and shear strain
* copies are inserted as additional channels.
*
* Throws an Exception if any of the four result channels is missing, or if the
* number of input atoms differs from the number of stored results.
******************************************************************************/
EvaluationStatus CalculateExtrinsicStrainModifier::applyResult(TimeTicks time, TimeInterval& validityInterval)
{
	// All four result channels are cloned below. The constructor only creates them
	// when not loading, so any of them may be NULL if it was not restored from a
	// file. Check every one of them, not just the two tensor channels.
	if(deformationGradients() == NULL || strainTensors() == NULL ||
			hydrostaticStrains() == NULL || shearStrains() == NULL)
		throw Exception(tr("No deformation analysis results available."));

	// Check if it is still valid.
	if(input()->atomsCount() != deformationGradients()->size())
		throw Exception(tr("Number of atoms of input object has changed. Analysis results became invalid."));

	CloneHelper cloneHelper;

	// Create a copy of the internal buffer channels and assign them to the output AtomsObject.

	DataChannel::SmartPtr gradClone = cloneHelper.cloneObject(deformationGradients(), true);
	output()->replaceDataChannel(outputStandardChannel(DataChannel::DeformationGradientChannel), gradClone.get());

	DataChannel::SmartPtr strainClone = cloneHelper.cloneObject(strainTensors(), true);
	output()->replaceDataChannel(outputStandardChannel(DataChannel::StrainTensorChannel), strainClone.get());

	DataChannel::SmartPtr hydrostaticStrainClone = cloneHelper.cloneObject(hydrostaticStrains(), true);
	output()->insertDataChannel(hydrostaticStrainClone);

	DataChannel::SmartPtr shearStrainClone = cloneHelper.cloneObject(shearStrains(), true);
	output()->insertDataChannel(shearStrainClone);

	return EvaluationStatus();
}

/******************************************************************************
* This is the actual analysis method.
*
* Requires a displacement channel on the input and delegates to calculate().
* Returns an error status if calculate() reports that it did not finish.
******************************************************************************/
EvaluationStatus CalculateExtrinsicStrainModifier::doAnalysis(TimeTicks time, bool suppressDialogs)
{
	// The input object must carry displacement vectors.
	expectStandardChannel(DataChannel::DisplacementChannel);

	const bool completed = calculate(input(), suppressDialogs);
	if(!completed)
		return EvaluationStatus(EvaluationStatus::EVALUATION_ERROR, tr("Calculation has been canceled by the user."));

	return EvaluationStatus();
}

/******************************************************************************
* This structure combines all vectors and matrices needed to pass the data
* to the GSL least-square fit routine.
*
* It is sized for a fit with 3*numNeighbors observations and 3*3 = 9 unknowns
* (the components of a 3x3 tensor). The structure owns its GSL objects and
* frees them in the destructor. It is deliberately non-copyable: a copy would
* share the raw GSL pointers, and both destructors would free them.
******************************************************************************/
struct LeastSquareWorkspace
{
	/// Allocates all GSL objects for a fit with the given number of neighbors.
	/// The design matrix is zeroed once here, so callers that overwrite only the
	/// same non-zero entries for every fit can reuse the workspace.
	explicit LeastSquareWorkspace(int numNeighbors) {
		lsqFitWorkspace = gsl_multifit_linear_alloc(numNeighbors*3, 3*3);
		Xmatrix = gsl_matrix_alloc(numNeighbors*3, 3*3);
		Yvector = gsl_vector_alloc(numNeighbors*3);
		Cvector = gsl_vector_alloc(3*3);
		COVmatrix = gsl_matrix_alloc(3*3, 3*3);
		gsl_matrix_set_zero(Xmatrix);
	}

	/// Releases all GSL objects owned by this workspace.
	~LeastSquareWorkspace() {
		gsl_multifit_linear_free(lsqFitWorkspace);
		gsl_matrix_free(Xmatrix);
		gsl_vector_free(Yvector);
		gsl_vector_free(Cvector);
		gsl_matrix_free(COVmatrix);
	}

	gsl_multifit_linear_workspace* lsqFitWorkspace;
	gsl_matrix* Xmatrix;		// Design matrix (3*numNeighbors x 9).
	gsl_vector* Yvector;		// Observation vector (3*numNeighbors).
	gsl_vector* Cvector;		// Fitted coefficients (9).
	gsl_matrix* COVmatrix;		// Covariance matrix of the coefficients (9 x 9).

private:
	// Copying is disabled (declared but intentionally not defined) because the
	// structure owns raw GSL pointers that must be freed exactly once.
	LeastSquareWorkspace(const LeastSquareWorkspace&);
	LeastSquareWorkspace& operator=(const LeastSquareWorkspace&);
};


/******************************************************************************
* Calculates the local deformation of each atom.
*
* For every atom, a 3x3 deformation gradient F is obtained by a linear
* least-squares fit (GSL) that maps the atom's neighbor vectors in the
* reference configuration onto the corresponding vectors in the current
* configuration. Reference vectors are reconstructed from the current neighbor
* vectors and the per-atom displacements. From F, the strain tensor
* E = (Product_AtA(F) - I) / 2 and its hydrostatic and shear invariants are
* stored in the result channels.
*
* Atoms with fewer than three neighbors, or whose fitted F has a determinant
* outside (0, maxDeterminant], receive zero results and are counted in
* warnings.
*
* Returns false, with all result channels truncated to zero length, if
* neighbor-list preparation does not complete or the user cancels.
* Throws an Exception if a required input channel is missing or the GSL fit
* fails. On a GSL failure the result channels are truncated as well, so that
* partially computed results are never applied later.
******************************************************************************/
bool CalculateExtrinsicStrainModifier::calculate(AtomsObject* atomsObject, bool suppressDialogs)
{
	// Fits whose deformation gradient has a determinant above this bound (or
	// not positive) are rejected as invalid.
	const FloatType maxDeterminant = 20.0;

	// The position channel is not read here, but its presence is required.
	if(!atomsObject->getStandardDataChannel(DataChannel::PositionChannel))
		throw Exception(tr("Input atoms object does not contain a position channel."));

	// Make sure that the displacement analysis has been performed on these atoms.
	DataChannel* displacementChannel = atomsObject->getStandardDataChannel(DataChannel::DisplacementChannel);
	if(!displacementChannel) throw Exception(tr("Input atoms object does not contain a displacement channel. The Calculate Displacements modifier has to be applied first."));

	const int numAtoms = atomsObject->atomsCount();

	ProgressIndicator progress(tr("Calculating extrinsic strain tensors."), numAtoms, suppressDialogs);

	// Prepare the neighbor list.
	OnTheFlyNeighborList neighborList(nearestNeighborList()->nearestNeighborCutoff());
	if(!neighborList.prepare(atomsObject, suppressDialogs)) {
		deformationGradientChannel->setSize(0);
		strainTensorChannel->setSize(0);
		hydrostaticStrainChannel->setSize(0);
		shearStrainChannel->setSize(0);
		return false;
	}

	// Prepare the output channels.
	deformationGradientChannel->setSize(numAtoms);
	strainTensorChannel->setSize(numAtoms);
	hydrostaticStrainChannel->setSize(numAtoms);
	shearStrainChannel->setSize(numAtoms);

	// For each number of nearest neighbors, this map holds the matching
	// workspace for the linear least-squares routine. Workspaces are cached
	// because allocating them for every atom individually is too slow.
	QMap<int, shared_ptr<LeastSquareWorkspace> > lsqWorkspaces;

	int numUnderdetermined = 0;
	int numInvalid = 0;

	// Neighbor vectors of the current atom. The buffers are reused across atoms;
	// only the first numNeighbors entries are valid for the current atom.
	vector<Vector3> referenceVectors;
	vector<Vector3> deformedVectors;

	// Iterate over all atoms.
	for(int currentAtomIndex = 0; currentAtomIndex < numAtoms; currentAtomIndex++) {

		// Update progress indicator.
		if((currentAtomIndex % 4096) == 0) {
			progress.setValue(currentAtomIndex);
			if(progress.isCanceled()) {
				// Throw away results obtained so far if the user cancels the calculation.
				deformationGradientChannel->setSize(0);
				strainTensorChannel->setSize(0);
				hydrostaticStrainChannel->setSize(0);
				shearStrainChannel->setSize(0);
				return false;
			}
		}

		// Gather neighbor vectors of the current atom in the deformed configuration,
		// and reconstruct them in the reference configuration from the displacements.
		int numNeighbors = 0;
		const Vector3 centerDispl = displacementChannel->getVector3(currentAtomIndex);
		for(OnTheFlyNeighborList::iterator neighborIter(neighborList, currentAtomIndex); !neighborIter.atEnd(); neighborIter.next(), numNeighbors++) {
			if(static_cast<size_t>(numNeighbors) >= referenceVectors.size()) {
				deformedVectors.push_back(neighborIter.delta());
				referenceVectors.push_back(neighborIter.delta() + centerDispl - displacementChannel->getVector3(neighborIter.current()));
			}
			else {
				deformedVectors[numNeighbors] = neighborIter.delta();
				referenceVectors[numNeighbors] = neighborIter.delta() + centerDispl - displacementChannel->getVector3(neighborIter.current());
			}
		}

		// Use GSL to do a least-square fit.
		// Find a deformation gradient that minimizes the value of CHISQUARE.
		// CHISQUARE is defined as the sum of all squared distances between the
		// nearest neighbors and their respective positions in the reference configuration.

		// The linear least-square fit finds a vector C such that
		// Y = X * C,
		// with Y being a vector that contains the positions of all nearest-neighbors in the
		// deformed configuration and X being a matrix that contains the positions of all nearest-neighbors
		// in the reference configuration.
		// The matrix C then contains the optimum deformation tensor.

		// For atoms with less than 3 neighbors the strain tensor is not defined.
		if(numNeighbors < 3) {
			deformationGradientChannel->setTensor2(currentAtomIndex, NULL_MATRIX);
			strainTensorChannel->setSymmetricTensor2(currentAtomIndex, SymmetricTensor2(0));
			hydrostaticStrainChannel->setFloat(currentAtomIndex, 0);
			shearStrainChannel->setFloat(currentAtomIndex, 0);
			numUnderdetermined++;
			continue;
		}

		// Fetch (or create on first use) the cached workspace for this neighbor count.
		// A single map lookup: QMap::operator[] default-inserts an empty pointer.
		shared_ptr<LeastSquareWorkspace>& wsSlot = lsqWorkspaces[numNeighbors];
		if(!wsSlot)
			wsSlot.reset(new LeastSquareWorkspace(numNeighbors));
		LeastSquareWorkspace* ws = wsSlot.get();

		// Setup vectors and matrices. Only entries (3i+k, 3k+j) of X are written;
		// all other entries stay zero from the workspace initialization.
		for(int i = 0; i < numNeighbors; i++) {

			const Vector3& r0 = referenceVectors[i];
			const Vector3& r1 = deformedVectors[i];

			// Store vectors in the GSL data structure.
			for(int k = 0; k < 3; k++) {
				gsl_vector_set(ws->Yvector, i*3+k, r1[k]);
				for(int j = 0; j < 3; j++)
					gsl_matrix_set(ws->Xmatrix, i*3+k, k*3+j, r0[j]);
			}
		}

		// Do linear least-square fit.
		double CHISQ = 0;
		int errorCode = gsl_multifit_linear(ws->Xmatrix, ws->Yvector, ws->Cvector, ws->COVmatrix, &CHISQ, ws->lsqFitWorkspace);
		if(errorCode) {
			// Discard the partially filled channels, as on cancellation. Otherwise
			// their size would still match the atom count and applyResult() would
			// accept them.
			deformationGradientChannel->setSize(0);
			strainTensorChannel->setSize(0);
			hydrostaticStrainChannel->setSize(0);
			shearStrainChannel->setSize(0);
			throw Exception(tr("The gsl_multifit_linear() function failed to do a linear least-square fit for atom %1. Error code: %2").arg(currentAtomIndex).arg(errorCode));
		}

		// Store tensor in output channel: F(i,j) = C[3i+j].
		Tensor2& F = *(deformationGradientChannel->dataTensor2() + currentAtomIndex);
		for(int i = 0; i < 3; i++)
			for(int j = 0; j < 3; j++)
				F(i, j) = gsl_vector_get(ws->Cvector, i*3+j);

		// Reject fits that invert, collapse, or implausibly expand the neighborhood.
		FloatType det = F.determinant();
		if(det <= 0.0 || det > maxDeterminant) {
			deformationGradientChannel->setTensor2(currentAtomIndex, NULL_MATRIX);
			strainTensorChannel->setSymmetricTensor2(currentAtomIndex, SymmetricTensor2(0));
			hydrostaticStrainChannel->setFloat(currentAtomIndex, 0);
			shearStrainChannel->setFloat(currentAtomIndex, 0);
			numInvalid++;
			continue;
		}

		// Calculate strain tensor.
		SymmetricTensor2 strain = (Product_AtA(F) - IDENTITY) * 0.5;
		strainTensorChannel->setSymmetricTensor2(currentAtomIndex, strain);

		// Calculate hydrostatic strain invariant (one third of the trace).
		hydrostaticStrainChannel->setFloat(currentAtomIndex, (strain(0,0) + strain(1,1) + strain(2,2)) / 3.0);

		// Calculate shear strain invariant.
		shearStrainChannel->setFloat(currentAtomIndex, sqrt(
				square(strain(0,1)) + square(strain(1,2)) + square(strain(0,2)) +
				(square(strain(1,1) - strain(2,2)) + square(strain(0,0) - strain(2,2)) + square(strain(0,0) - strain(1,1))) / 6.0));
	}

	if(numUnderdetermined)
		MsgLogger() << "WARNING: Found" << numUnderdetermined << "atoms with less than three neighbor atoms in the cutoff radius. Could not calculate strain tensor for these undercoordinated atoms. Please increase cutoff radius." << endl;
	if(numInvalid)
		MsgLogger() << "WARNING: Found" << numInvalid << "atoms with invalid deformation gradient. Could not calculate strain tensor for these atoms. Please increase cutoff radius." << endl;

	return true;
}

// Plugin class registration for the modifier's editor class, declared with
// AtomsObjectModifierEditorBase as its base (macro defined elsewhere).
IMPLEMENT_PLUGIN_CLASS(CalculateExtrinsicStrainModifierEditor, AtomsObjectModifierEditorBase)

/******************************************************************************
* Sets up the UI widgets of the editor.
******************************************************************************/
void CalculateExtrinsicStrainModifierEditor::createUI(const RolloutInsertionParameters& rolloutParams)
{
	// Create a rollout.
	QWidget* rollout = createRollout(tr("Calculate Extrinsic Strain"), rolloutParams);

    // Create the rollout contents.
	QVBoxLayout* layout1 = new QVBoxLayout(rollout);
	layout1->setContentsMargins(4,4,4,4);
	layout1->setSpacing(0);

	BooleanPropertyUI* autoUpdateUI = new BooleanPropertyUI(this, PROPERTY_FIELD_DESCRIPTOR(AtomsObjectAnalyzerBase, _autoUpdateOnTimeChange));
	layout1->addWidget(autoUpdateUI->checkBox());

	BooleanPropertyUI* saveResultsUI = new BooleanPropertyUI(this, "storeResultsWithScene", tr("Save results in scene file"));
	layout1->addWidget(saveResultsUI->checkBox());

	QPushButton* recalcButton = new QPushButton(tr("Calculate"), rollout);
	layout1->addSpacing(6);
	layout1->addWidget(recalcButton);
	connect(recalcButton, SIGNAL(clicked(bool)), this, SLOT(onCalculate()));

	// Status label.
	layout1->addSpacing(10);
	layout1->addWidget(statusLabel());

	// Open sub-editors for the result channels.
	//new SubObjectParameterUI(this, PROPERTY_FIELD_DESCRIPTOR(CalculateExtrinsicStrainModifier, deformationGradientChannel), rolloutParams.after(rollout).collapse());
	//new SubObjectParameterUI(this, PROPERTY_FIELD_DESCRIPTOR(CalculateExtrinsicStrainModifier, strainTensorChannel), rolloutParams.after(rollout).collapse());
	//new SubObjectParameterUI(this, PROPERTY_FIELD_DESCRIPTOR(CalculateExtrinsicStrainModifier, hydrostaticStrainChannel), rolloutParams.after(rollout).collapse());
	//new SubObjectParameterUI(this, PROPERTY_FIELD_DESCRIPTOR(CalculateExtrinsicStrainModifier, shearStrainChannel), rolloutParams.after(rollout).collapse());

	// Open a sub-editor for the NearestNeighborList sub-object.
	SubObjectParameterUI* subEditorUI = new SubObjectParameterUI(this, PROPERTY_FIELD_DESCRIPTOR(AtomsObjectAnalyzerBase, _nearestNeighborList), rolloutParams.before(rollout));
}

/******************************************************************************
* Is called when the user presses the "Calculate" button.
******************************************************************************/
void CalculateExtrinsicStrainModifierEditor::onCalculate()
{
	if(!editObject()) return;
	CalculateExtrinsicStrainModifier* modifier = static_object_cast<CalculateExtrinsicStrainModifier>(editObject());
	try {
		modifier->performAnalysis(ANIM_MANAGER.time());
	}
	catch(Exception& ex) {
		ex.prependGeneralMessage(tr("Failed to analyze atoms."));
		ex.showError();
	}
}

};	// End of namespace CrystalAnalysis
