/*
 * Decompiled with CFR 0.152.
 */
package beast.evolution.speciation;

import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Function;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.speciation.CalibratedYuleModel;
import beast.evolution.speciation.CalibrationPoint;
import beast.evolution.speciation.SpeciesTreePrior;
import beast.evolution.tree.Node;
import beast.evolution.tree.RandomTree;
import beast.evolution.tree.Tree;
import beast.evolution.tree.coalescent.ConstantPopulation;
import beast.math.distributions.MRCAPrior;
import beast.util.ClusterTree;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.MathException;

/**
 * Initialiser for a *BEAST analysis.
 *
 * <p>Sets starting values for the species tree and the gene trees. When they are supplied, it also
 * sets the birth rate, the population-mean hyper-parameter and the species-tree-prior
 * population sizes. {@link #initStateNodes()} picks one of these strategies:
 * <ul>
 *   <li>a calibrated Yule model was supplied: use a calibration-compatible tree obtained from a
 *       copy of that model, plus caterpillar gene trees;</li>
 *   <li>{@link MRCAPrior}s are attached to the species tree: use a {@link RandomTree} constrained
 *       by those priors, plus caterpillar gene trees;</li>
 *   <li>otherwise: either a point estimate built by UPGMA clustering of the gene alignments
 *       ({@code method="point-estimate"}, the default) or a rescaled random state
 *       ({@code method="random"}).</li>
 * </ul>
 */
@Description(value="Set a starting point for a *BEAST analysis from gene alignment data.")
public class StarBeastStartState extends Tree implements StateNodeInitialiser {
    public final Input<Method> initMethod = new Input<>("method",
            "Initialise either with a totally random state or a point estimate based on alignments data (default point-estimate)",
            Method.POINT, Method.values());
    public final Input<Tree> speciesTreeInput = new Input<>("speciesTree", "The species tree to initialize");
    public final Input<List<Tree>> genes = new Input<>("gene", "Gene trees to initialize", new ArrayList<>());
    public final Input<CalibratedYuleModel> calibratedYule = new Input<>("calibratedYule",
            "The species tree (with calibrations) to initialize", Input.Validate.XOR, speciesTreeInput);
    public final Input<RealParameter> popMean = new Input<>("popMean", "Population mean hyper prior to initialse");
    public final Input<RealParameter> birthRate = new Input<>("birthRate", "Tree prior birth rate to initialize");
    public final Input<SpeciesTreePrior> speciesTreePriorInput = new Input<>("speciesTreePrior",
            "Population size parameters to initialise");
    public final Input<Function> muInput = new Input<>("baseRate", "Main clock rate used to scale trees (default 1).");

    /** True when a calibrated Yule model (instead of a plain species tree) was supplied. */
    private boolean hasCalibrations;

    @Override
    public void initAndValidate() {
        super.initAndValidate();
        hasCalibrations = calibratedYule.get() != null;
    }

    /**
     * Chooses an initialisation strategy and runs it.
     *
     * <p>Any {@link MRCAPrior} among the outputs of the species tree counts as a calibration. The
     * species tree is the calibrated Yule model's tree when that input is used. This matters
     * because {@code calibratedYule} is XOR with {@code speciesTree}, so the {@code speciesTree}
     * input is null in that case.
     *
     * @throws IllegalArgumentException if a calibrated Yule model is combined with MRCA priors, or
     *         if computing the calibrated initial tree throws a {@link MathException} (wrapped)
     */
    @Override
    public void initStateNodes() {
        final List<MRCAPrior> mrcaPriors = new ArrayList<>();
        for (final BEASTInterface output : resolveSpeciesTree().getOutputs()) {
            if (output instanceof MRCAPrior) {
                mrcaPriors.add((MRCAPrior) output);
            }
        }

        if (hasCalibrations) {
            if (!mrcaPriors.isEmpty()) {
                throw new IllegalArgumentException("Not implemented: mix of calibrated yule and MRCA priors: place all priors in the calibrated Yule");
            }
            try {
                initWithCalibrations();
            } catch (final MathException e) {
                throw new IllegalArgumentException(e);
            }
            return;
        }

        if (!mrcaPriors.isEmpty()) {
            initWithMRCACalibrations(mrcaPriors);
            return;
        }

        switch (initMethod.get()) {
            case POINT:
                fullInit();
                break;
            case ALL_RANDOM:
                randomInit();
                break;
        }
    }

    /**
     * Returns the species tree being initialised.
     *
     * <p>This is the calibrated Yule model's tree when that input was given; otherwise it is the
     * {@code speciesTree} input.
     */
    private Tree resolveSpeciesTree() {
        return hasCalibrations ? (Tree) calibratedYule.get().treeInput.get() : speciesTreeInput.get();
    }

    /**
     * Computes, for every unordered pair of species (a, b), the lowest height of an internal node
     * of {@code tree} where one child subtree contains species a and the other contains species b.
     *
     * <p>Species present below both children of a node are not paired at that node. Pairs that
     * never meet keep {@link Double#MAX_VALUE}.
     *
     * @param tree      the tree to scan (assumed binary; checked by an assertion)
     * @param speciesOf maps leaf IDs to species indices. A leaf missing from the map contributes
     *                  {@code null}, which later fails on unboxing. NOTE(review): callers are
     *                  assumed to map every leaf.
     * @param n         number of species
     * @return the condensed pair array, indexed via {@link #getDMindex(int, int, int)}
     */
    private static double[] firstMeetings(final Tree tree, final Map<String, Integer> speciesOf, final int n) {
        final Node[] nodes = tree.listNodesPostOrder(null, null);

        // Species found below each node, indexed by node number.
        final List<Set<Integer>> below = new ArrayList<>(nodes.length);
        for (int i = 0; i < nodes.length; ++i) {
            below.add(new HashSet<>());
        }

        final double[] meetings = new double[n * (n - 1) / 2];
        Arrays.fill(meetings, Double.MAX_VALUE);

        for (final Node node : nodes) {
            if (node.isLeaf()) {
                below.get(node.getNr()).add(speciesOf.get(node.getID()));
                continue;
            }
            assert node.getChildCount() == 2;
            final Set<Integer> left = below.get(node.getChild(0).getNr());
            final Set<Integer> right = below.get(node.getChild(1).getNr());

            // Species on both sides have already met lower down; only pair the exclusive ones.
            final Set<Integer> shared = new HashSet<>(left);
            shared.retainAll(right);
            left.removeAll(shared);
            right.removeAll(shared);

            for (final Integer a : left) {
                for (final Integer b : right) {
                    final int k = getDMindex(n, a, b);
                    meetings[k] = Math.min(meetings[k], node.getHeight());
                }
            }

            shared.addAll(left);
            shared.addAll(right);
            below.set(node.getNr(), shared);
        }
        return meetings;
    }

    /**
     * Computes the index of the pair (i, j) in a condensed upper-triangular n x n matrix.
     *
     * <p>Rows are stored one after another, so pairs (0,1)..(0,n-1) come first, followed by
     * (1,2)..(1,n-1), and so on. Requires {@code i != j}.
     */
    private static int getDMindex(final int n, final int i, final int j) {
        final int lo = Math.min(i, j);
        return lo * (2 * n - 1 - lo) / 2 + (Math.abs(i - j) - 1);
    }

    /**
     * Builds the point-estimate starting state.
     *
     * <p>The steps are:
     * <ol>
     *   <li>UPGMA-cluster each gene tree from its alignment and rescale it by 1/rate;</li>
     *   <li>UPGMA-cluster the species tree from the mean first-meeting heights across genes;</li>
     *   <li>re-cluster any gene tree in which two species meet at or below their species-tree
     *       split;</li>
     *   <li>derive the birth rate and population sizes from the species tree.</li>
     * </ol>
     *
     * @throws RuntimeException if a gene tree has no lineages for some species
     */
    private void fullInit() {
        final Function mu = muInput.get();
        final double rate = mu != null ? mu.getArrayValue() : 1.0;

        final Tree speciesTree = speciesTreeInput.get();
        final TaxonSet species = speciesTree.m_taxonset.get();
        final List<String> speciesNames = species.asStringList();
        final int speciesCount = speciesNames.size();
        final List<Tree> geneTrees = genes.get();

        // Step 1: UPGMA start for every gene tree, then rescale from substitutions to time.
        double maxSiteCount = 0.0;
        for (final Tree gene : geneTrees) {
            final Alignment alignment = gene.m_taxonset.get().alignmentInput.get();
            final ClusterTree upgma = new ClusterTree();
            upgma.initByName("initial", gene, "clusterType", "upgma", "taxa", alignment);
            gene.scale(1.0 / rate);
            maxSiteCount = Math.max(maxSiteCount, alignment.getSiteCount());
        }

        // Map each gene-tree taxon ID to the index of the species (in taxon-set order) containing it.
        final Map<String, Integer> geneTaxonToSpecies = new HashMap<>();
        final List<Taxon> speciesTaxa = species.taxonsetInput.get();
        for (int s = 0; s < speciesCount; ++s) {
            final TaxonSet speciesSet = (TaxonSet) speciesTaxa.get(s);
            for (final Taxon taxon : speciesSet.taxonsetInput.get()) {
                geneTaxonToSpecies.put(taxon.getID(), s);
            }
        }

        // Sum the per-gene first-meeting heights of each species pair.
        final double[] speciesDistances = new double[speciesCount * (speciesCount - 1) / 2];
        final double[][] geneMeetings = new double[geneTrees.size()][];
        for (int g = 0; g < geneTrees.size(); ++g) {
            final Tree gene = geneTrees.get(g);
            final double[] meetings = firstMeetings(gene, geneTaxonToSpecies, speciesCount);
            geneMeetings[g] = meetings;
            for (int i = 0; i < meetings.length; ++i) {
                speciesDistances[i] += meetings[i];
                if (meetings[i] == Double.MAX_VALUE) {
                    throw new RuntimeException("Gene tree " + gene.getID() + " has no lineages for species taxon "
                            + missingSpeciesId(meetings, i, speciesNames) + " ");
                }
            }
        }

        // Turn the summed heights into mean pairwise distances.
        for (int i = 0; i < speciesDistances.length; ++i) {
            final double meanHeight = speciesDistances[i] / geneTrees.size();
            speciesDistances[i] = meanHeight == 0.0
                    ? (0.5 / maxSiteCount) * (1.0 / rate) // keep the distance strictly positive
                    : 2.0 * meanHeight;                   // node height -> pairwise distance
        }

        // Step 2: UPGMA species tree from the mean distances.
        final Distance speciesDistance = new Distance() {
            @Override
            public double pairwiseDistance(final int taxon1, final int taxon2) {
                return speciesDistances[getDMindex(speciesCount, taxon1, taxon2)];
            }
        };
        final ClusterTree speciesUpgma = new ClusterTree();
        speciesUpgma.initByName("initial", speciesTree, "taxonset", species, "clusterType", "upgma",
                "distance", speciesDistance);

        // Step 3: species-tree split heights, and re-clustering of gene trees that dip below them.
        final Map<String, Integer> speciesIndex = new HashMap<>();
        for (int s = 0; s < speciesCount; ++s) {
            speciesIndex.put(speciesNames.get(s), s);
        }
        final double[] speciesHeights = firstMeetings(speciesTree, speciesIndex, speciesCount);
        for (int g = 0; g < geneTrees.size(); ++g) {
            if (!isGeneTreeCompatible(geneMeetings[g], speciesHeights)) {
                reclusterGeneTree(geneTrees.get(g), geneTaxonToSpecies, speciesHeights, speciesCount, rate);
            }
        }

        // Step 4: remaining parameters.
        initBirthRate(speciesTree, speciesCount);
        initPopulationSizes(speciesTree);
    }

    /**
     * Guesses which species has no lineages in a gene tree, for use in an error message.
     *
     * <p>Entries 0..n-2 of the condensed matrix are the pairs (0, k+1). An unmet pair at
     * {@code index < n-1} therefore points at species {@code index + 1}. If every pair involving
     * species 0 is unmet, species 0 itself is blamed.
     */
    private static String missingSpeciesId(final double[] meetings, final int index, final List<String> speciesNames) {
        final int speciesCount = speciesNames.size();
        String id = index < speciesCount - 1 ? speciesNames.get(index + 1) : "unknown taxon";
        if (index == 0) {
            boolean allUnmet = true;
            for (int k = 1; allUnmet && k < speciesCount - 1; ++k) {
                allUnmet = meetings[k] == Double.MAX_VALUE;
            }
            if (allUnmet) {
                id = speciesNames.get(0);
            }
        }
        return id;
    }

    /**
     * Checks a gene tree against the species tree.
     *
     * @return true when every gene-tree first-meeting height is strictly above the matching
     *         species-tree height
     */
    private static boolean isGeneTreeCompatible(final double[] geneHeights, final double[] speciesHeights) {
        for (int j = 0; j < speciesHeights.length; ++j) {
            if (geneHeights[j] <= speciesHeights[j]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Re-clusters a gene tree with UPGMA using Jukes-Cantor distances divided by {@code rate}.
     *
     * <p>The distance between taxa of different species is raised to just above twice their
     * species-tree split height. This means the new tree cannot join them below that split.
     */
    private void reclusterGeneTree(final Tree gene, final Map<String, Integer> geneTaxonToSpecies,
                                   final double[] speciesHeights, final int speciesCount, final double rate) {
        final TaxonSet geneTaxa = gene.m_taxonset.get();
        final Alignment alignment = geneTaxa.alignmentInput.get();
        final List<String> taxaNames = alignment.getTaxaNames();

        // Alignment taxon index -> species index.
        final int[] taxonToSpecies = new int[taxaNames.size()];
        for (int j = 0; j < taxonToSpecies.length; ++j) {
            taxonToSpecies[j] = geneTaxonToSpecies.get(taxaNames.get(j));
        }

        final JukesCantorDistance jc = new JukesCantorDistance();
        jc.setPatterns(alignment);
        final Distance geneDistance = new Distance() {
            @Override
            public double pairwiseDistance(final int taxon1, final int taxon2) {
                final int s1 = taxonToSpecies[taxon1];
                final int s2 = taxonToSpecies[taxon2];
                double distance = jc.pairwiseDistance(taxon1, taxon2) / rate;
                if (s1 != s2) {
                    final double minDistance = 2.0 * speciesHeights[getDMindex(speciesCount, s1, s2)];
                    if (distance <= minDistance) {
                        distance = minDistance * 1.001;
                    }
                }
                return distance;
            }
        };
        final ClusterTree upgma = new ClusterTree();
        upgma.initByName("initial", gene, "taxonset", geneTaxa, "clusterType", "upgma", "distance", geneDistance);
    }

    /**
     * Sets the birth rate (if supplied) to {@code (1 / rootHeight) * sum_{i=2..n} 1/i}.
     *
     * <p>The value is clamped to the parameter bounds.
     */
    private void initBirthRate(final Tree speciesTree, final int speciesCount) {
        final RealParameter lambda = birthRate.get();
        if (lambda != null) {
            final double rootHeight = speciesTree.getRoot().getHeight();
            setParameterValue(lambda, (1.0 / rootHeight) * harmonicSumFromTwo(speciesCount));
        }
    }

    /**
     * Derives population-size starting values from the species-tree branch lengths.
     *
     * <p>The base value is the total non-root branch length divided by twice the number of
     * non-root nodes. {@code popMean} and every {@code popSizesTop} entry get this value; every
     * {@code popSizesBottom} entry gets twice it. All values are clamped to the parameter bounds.
     */
    private void initPopulationSizes(final Tree speciesTree) {
        final Node[] nodes = speciesTree.getNodesAsArray();
        double base = 0.0;
        for (final Node node : nodes) {
            if (!node.isRoot()) {
                base += node.getLength();
            }
        }
        base /= 2 * (nodes.length - 1);

        final RealParameter mean = popMean.get();
        if (mean != null) {
            setParameterValue(mean, base);
        }

        final SpeciesTreePrior prior = speciesTreePriorInput.get();
        if (prior != null) {
            final RealParameter bottom = prior.popSizesBottomInput.get();
            if (bottom != null) {
                for (int i = 0; i < bottom.getDimension(); ++i) {
                    setParameterValue(bottom, i, 2.0 * base);
                }
            }
            final RealParameter top = prior.popSizesTopInput.get();
            if (top != null) {
                for (int i = 0; i < top.getDimension(); ++i) {
                    setParameterValue(top, i, base);
                }
            }
        }
    }

    /** Returns {@code sum_{i=2..n} 1/i}, which is 0 for {@code n < 2}. */
    private static double harmonicSumFromTwo(final int n) {
        double sum = 0.0;
        for (int i = 2; i <= n; ++i) {
            sum += 1.0 / i;
        }
        return sum;
    }

    /** Sets element 0 of {@code parameter} to {@code value}, clamped to the parameter's bounds. */
    private void setParameterValue(final RealParameter parameter, final double value) {
        setParameterValue(parameter, 0, value);
    }

    /** Sets element {@code index} of {@code parameter} to {@code value}, clamped to the parameter's bounds. */
    private void setParameterValue(final RealParameter parameter, final int index, double value) {
        final double lower = parameter.getLower();
        final double upper = parameter.getUpper();
        if (value < lower) {
            value = lower;
        }
        if (value > upper) {
            value = upper;
        }
        parameter.setValue(index, value);
    }

    /**
     * Resets every gene tree to a caterpillar shape.
     *
     * <p>The call passes {@code rootHeight} and {@code rootHeight / internalNodeCount} to
     * {@link Tree#makeCaterpillar}.
     */
    private void randomInitGeneTrees(final double rootHeight) {
        for (final Tree gene : genes.get()) {
            gene.makeCaterpillar(rootHeight, rootHeight / gene.getInternalNodeCount(), true);
        }
    }

    /**
     * Random start.
     *
     * <p>Scales the species tree to a root height of {@code (1 / lambda) * sum_{i=2..n} 1/i}, where
     * lambda is the birth rate (default 1). The gene trees are then reset to caterpillars of that
     * height.
     */
    private void randomInit() {
        double lambda = 1.0;
        final RealParameter rate = birthRate.get();
        if (rate != null) {
            lambda = rate.getArrayValue();
        }
        final Tree speciesTree = speciesTreeInput.get();
        final int speciesCount = speciesTree.m_taxonset.get().asStringList().size();
        final double rootHeight = (1.0 / lambda) * harmonicSumFromTwo(speciesCount);
        speciesTree.scale(rootHeight / speciesTree.getRoot().getHeight());
        randomInitGeneTrees(rootHeight);
    }

    /**
     * Calibrated start.
     *
     * <p>Builds a throw-away copy of the calibrated Yule model with type NONE and the same
     * calibrations. It asks that copy for a compatible initial tree and copies the result into the
     * species tree. The gene trees are then reset to caterpillars of the new root height, and
     * finally the original model is re-initialised.
     *
     * @throws MathException propagated from {@link CalibratedYuleModel#compatibleInitialTree()}
     */
    private void initWithCalibrations() throws MathException {
        final CalibratedYuleModel model = calibratedYule.get();
        final Tree speciesTree = (Tree) model.treeInput.get();

        final CalibratedYuleModel helper = new CalibratedYuleModel();
        helper.getOutputs().addAll(model.getOutputs());
        for (final CalibrationPoint calibration : model.calibrationsInput.get()) {
            helper.setInputValue("calibrations", calibration);
        }
        helper.setInputValue("tree", speciesTree);
        helper.setInputValue("type", CalibratedYuleModel.Type.NONE);
        helper.initAndValidate();

        final Tree initial = helper.compatibleInitialTree();
        assert speciesTree.getLeafNodeCount() == initial.getLeafNodeCount();
        speciesTree.assignFromWithoutID(initial);

        randomInitGeneTrees(speciesTree.getRoot().getHeight());
        // NOTE(review): presumably re-run so the model picks up the replaced tree — confirm.
        model.initAndValidate();
    }

    /**
     * MRCA-constrained start.
     *
     * <p>Draws a {@link RandomTree} over the species taxon set, constrained by {@code priors}, under
     * a constant population of size 1. The result is copied into the species tree, and the gene
     * trees are reset to caterpillars of its root height.
     */
    private void initWithMRCACalibrations(final List<MRCAPrior> priors) {
        final Tree speciesTree = speciesTreeInput.get();
        final RandomTree randomTree = new RandomTree();
        randomTree.setInputValue("taxonset", speciesTree.getTaxonset());
        for (final MRCAPrior prior : priors) {
            randomTree.setInputValue("constraint", prior);
        }
        final ConstantPopulation population = new ConstantPopulation();
        population.setInputValue("popSize", new RealParameter("1.0"));
        randomTree.setInputValue("populationModel", population);
        randomTree.initAndValidate();

        speciesTree.assignFromWithoutID(randomTree);
        randomInitGeneTrees(speciesTree.getRoot().getHeight());
    }

    /**
     * Reports every state node this initialiser may write.
     *
     * <p>In order these are: the species tree, the gene trees, then {@code popMean},
     * {@code birthRate} and the species-tree-prior bottom/top population sizes when present.
     */
    @Override
    public void getInitialisedStateNodes(final List<StateNode> stateNodes) {
        stateNodes.add(resolveSpeciesTree());
        stateNodes.addAll(genes.get());

        final RealParameter mean = popMean.get();
        if (mean != null) {
            stateNodes.add(mean);
        }
        final RealParameter lambda = birthRate.get();
        if (lambda != null) {
            stateNodes.add(lambda);
        }

        final SpeciesTreePrior prior = speciesTreePriorInput.get();
        if (prior != null) {
            final RealParameter bottom = prior.popSizesBottomInput.get();
            if (bottom != null) {
                stateNodes.add(bottom);
            }
            final RealParameter top = prior.popSizesTopInput.get();
            if (top != null) {
                stateNodes.add(top);
            }
        }
    }

    /** Initialisation strategy; {@link #toString()} gives the XML value of the {@code method} input. */
    enum Method {
        POINT("point-estimate"),
        ALL_RANDOM("random");

        private final String ename;

        Method(final String ename) {
            this.ename = ename;
        }

        @Override
        public String toString() {
            return ename;
        }
    }
}

