"""
An example of using the Bayesian sampling library emcee with BornAgain.

author: Andrew McCluskey (andrew.mccluskey@ess.eu)
"""

# Import necessary modules
from os import path, getenv
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import differential_evolution
import emcee
import corner
import bornagain as ba

# Seed NumPy's global random number generator so that the random draws
# made by this script are repeatable from run to run.
np.random.seed(seed=1)

# Directory containing the example data files, read from the BA_DATA_DIR
# environment variable (empty string when the variable is unset).
datadir = getenv('BA_DATA_DIR', default='')


# The sample
def get_sample(ni_thickness, ti_thickness):
    """
    Build the Ti/Ni multilayer on a Si substrate and return it.

    From top to bottom the stack is: vacuum, ten repetitions of a
    (Ti layer, Ni layer) pair, then the semi-infinite Si substrate.

    :float ni_thickness: thickness of every Ni layer in nanometres
    :float ti_thickness: thickness of every Ti layer in nanometres
    :return: the assembled ba.MultiLayer
    """
    n_periods = 10

    # Materials are given by the real part of their scattering-length
    # density (in angstrom^-2); the imaginary part is zero for all of them.
    ti_layer = ba.Layer(ba.MaterialBySLD("Ti", -1.9493e-06, 0), ti_thickness)
    ni_layer = ba.Layer(ba.MaterialBySLD("Ni", 9.4245e-06, 0), ni_thickness)

    # The ambient medium and the substrate have no thickness argument.
    top_layer = ba.Layer(ba.MaterialBySLD())
    bottom_layer = ba.Layer(ba.MaterialBySLD("SiSubstrate", 2.0704e-06, 0))

    multilayer = ba.MultiLayer()
    multilayer.addLayer(top_layer)
    for _period in range(n_periods):
        for layer in (ti_layer, ni_layer):
            multilayer.addLayer(layer)
    multilayer.addLayer(bottom_layer)
    return multilayer


# Source the real data and add an uncertainty to the ordinate
def get_real_data():
    """
    Load the reflectivity data from genx_alternating_layers.dat.gz.

    Returns an Nx3 array (N - the number of experimental data entries)
    whose first column is the incident angle in radians, second column
    the measured reflectivity, and third column an artificial
    uncertainty equal to ten percent of the reflectivity.

    The file is read and processed only on the first call; the result is
    cached on the function object and every call returns a fresh copy, so
    callers may modify the returned array without corrupting the cache.
    """
    if not hasattr(get_real_data, "data"):
        filepath = path.join(datadir, 'genx_alternating_layers.dat.gz')
        real_data = ba.readData2D(filepath).npArray()
        # translating axis values from double incident angle (degs)
        # to incident angle (radians): 2*theta[deg] * pi/360 = theta[rad]
        real_data[:, 0] *= np.pi/360
        # setting artificial uncertainties (uncertainty sigma equals ten
        # percent of the experimental data value)
        real_data[:, 2] = real_data[:, 1]*0.1
        # store the processed array so later calls skip the file read
        get_real_data.data = real_data
    return get_real_data.data.copy()


# The simulation
def get_simulation(sample, alpha):
    """
    Set up a specular reflectivity simulation of a sample.

    :object sample: the sample to simulate (e.g. as built by get_sample)
    :array alpha: incident angles for the scan (in this script, radians
        as produced by get_real_data)
    :return: a ba.SpecularSimulation ready to be run
    """
    # The beam wavelength is fixed at 0.154 nm.
    return ba.SpecularSimulation(ba.AlphaScan(0.154, alpha), sample)


# Run the simulation
def run_simulation(alpha, ni_thickness, ti_thickness):
    """
    Runs simulation and returns its result.

    :array alpha: incident angles to be simulated (in this script,
        radians as produced by get_real_data)
    :float ni_thickness: a value of the Ni thickness in nanometres
    :float ti_thickness: a value of the Ti thickness in nanometres
    :return: simulated reflected intensity at the given incident angles
    """
    sample = get_sample(ni_thickness, ti_thickness)
    simulation = get_simulation(sample, alpha)

    result = simulation.simulate()
    return result.array()


# A log-likelihood function
def log_likelihood(theta, x, y, yerr):
    """
    Calculate the Gaussian log-likelihood of the data given the model.

    The variance of each point is yerr**2 + model**2: the supplied
    uncertainty plus a term equal to the squared model value. The
    constant ln(2*pi) term is dropped, which changes neither the
    optimisation nor the sampling.

    :tuple theta: the variable parameters (ni_thickness, ti_thickness)
    :array x: the abscissa data (incident angles, radians)
    :array y: the ordinate data (R-values)
    :array yerr: the ordinate uncertainty (dR-values)
    :return: log-likelihood
    """
    model = run_simulation(x, *theta)
    # NOTE(review): the model**2 term looks like the emcee tutorial's
    # fractional model-error term with f fixed to 1; confirm that this
    # inflation of the variance is intended.
    sigma2 = yerr**2 + model**2
    return -0.5*np.sum((y - model)**2/sigma2 + np.log(sigma2))


if __name__ == '__main__':
    # Load (and pre-process) the data once and reuse the columns below
    real_data = get_real_data()
    x_data = real_data[:, 0]  # incident angle (radians)
    y_data = real_data[:, 1]  # reflectivity
    yerr_data = real_data[:, 2]  # uncertainty
    data_args = (x_data, y_data, yerr_data)

    def nll(*args):
        """Negative log-likelihood, minimised to get the MLE."""
        return -log_likelihood(*args)

    # Using scipy.optimize.differential_evolution find the
    # maximum likelihood estimate within the given parameter bounds
    soln = differential_evolution(nll, ((5.0, 9.0), (1.0, 10.0)),
                                  args=data_args)
    ni_thickness_ml, ti_thickness_ml = soln.x
    print('MLE Ni Thickness', ni_thickness_ml, 'nm')
    print('MLE Ti Thickness', ti_thickness_ml, 'nm')

    # Perform the likelihood sampling, starting the walkers in a tight
    # ball around the maximum likelihood estimate
    nwalkers = 32
    ndim = len(soln.x)
    pos = soln.x + 1e-4*np.random.randn(nwalkers, ndim)

    sampler = emcee.EnsembleSampler(nwalkers,
                                    ndim,
                                    log_likelihood,
                                    args=data_args)
    n_steps = 1000
    sampler.run_mcmc(pos, n_steps, progress=True)

    # Drop the burn-in phase, during which the walkers are still spreading
    # out from the initial ball, before flattening the chain
    n_burn = 100
    flat_samples = sampler.get_chain(discard=n_burn, flat=True)

    # Plot and show corner plot of samples
    corner.corner(flat_samples,
                  labels=['Ni-thickness/nm', 'Ti-thickness/nm'])
    plt.show()

    # Plot and show the data with the reflectivity simulated at the
    # posterior-mean parameters
    plt.errorbar(x_data, y_data, yerr_data, marker='.', ls='')
    plt.plot(x_data,
             run_simulation(x_data, *flat_samples.mean(axis=0)),
             '-')
    # get_real_data converts the abscissa to radians
    plt.xlabel('$\\alpha$/rad')
    plt.ylabel('$R$')
    plt.yscale('log')
    plt.show()
