/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.util.Merger;

/**
 * A {@link Merger} which sums sparse vectors (and the rows of {@link DenseSparseMatrix}
 * instances) using a priority queue of {@link VectorIterator}s, performing a k-way merge.
 * <p>
 * NOTE(review): correctness of the merge relies on each vector's iterator yielding indices in
 * ascending order and on {@link VectorIterator}'s natural ordering comparing the current index;
 * both are assumed from the algorithm's design rather than visible here.
 */
public class HeapMerger implements Merger {
    private static final long serialVersionUID = 1L;
    // Currently unused.
    private static final Logger logger = Logger.getLogger(HeapMerger.class.getName());

    /**
     * Sums the supplied matrices row by row.
     * @param inputs The matrices to merge; all must share the same dimensions.
     * @return A new matrix whose row {@code i} is the sum of row {@code i} of every input.
     * @throws IllegalArgumentException If {@code inputs} is empty or the matrices have different dimensions.
     */
    @Override
    public DenseSparseMatrix merge(DenseSparseMatrix[] inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("Cannot merge an empty array of matrices.");
        }
        int denseLength = inputs[0].getDimension1Size();
        int sparseLength = inputs[0].getDimension2Size();
        for (DenseSparseMatrix m : inputs) {
            if (m.getDimension1Size() != denseLength || m.getDimension2Size() != sparseLength) {
                throw new IllegalArgumentException("All matrices must have dimensions " + denseLength + "x"
                        + sparseLength + ", found " + m.getDimension1Size() + "x" + m.getDimension2Size());
            }
        }

        // The scratch buffers are shared across rows, so size them for the densest merged row.
        int[] totalLengths = new int[denseLength];
        for (DenseSparseMatrix m : inputs) {
            for (int row = 0; row < denseLength; row++) {
                totalLengths[row] += m.numActiveElements(row);
            }
        }
        int maxLength = 0;
        for (int length : totalLengths) {
            maxLength = Math.max(maxLength, length);
        }

        SparseVector[] output = new SparseVector[denseLength];
        int[] indicesBuffer = new int[maxLength];
        double[] valuesBuffer = new double[maxLength];
        List<SparseVector> vectors = new ArrayList<>(inputs.length);
        for (int row = 0; row < denseLength; row++) {
            vectors.clear();
            for (DenseSparseMatrix m : inputs) {
                SparseVector vec = m.getRow(row);
                if (vec.numActiveElements() > 0) {
                    vectors.add(vec);
                }
            }
            output[row] = merge(vectors, sparseLength, indicesBuffer, valuesBuffer);
        }
        return DenseSparseMatrix.createFromSparseVectors(output);
    }

    /**
     * Sums the supplied sparse vectors.
     * @param inputs The vectors to merge; all must share the same size.
     * @return A new sparse vector containing the element-wise sum of the inputs.
     * @throws IllegalArgumentException If {@code inputs} is empty or the vectors have different sizes.
     */
    @Override
    public SparseVector merge(SparseVector[] inputs) {
        if (inputs.length == 0) {
            throw new IllegalArgumentException("Cannot merge an empty array of vectors.");
        }
        int dimension = inputs[0].size();
        int maxLength = 0;
        for (SparseVector v : inputs) {
            if (v.size() != dimension) {
                throw new IllegalArgumentException("All vectors must have size " + dimension + ", found " + v.size());
            }
            maxLength += v.numActiveElements();
        }
        return merge(Arrays.asList(inputs), dimension, new int[maxLength], new double[maxLength]);
    }

    /**
     * Merges a list of sparse vectors into a single sparse vector, summing values which share an index.
     * <p>
     * Empty vectors (and an empty list) are permitted; if nothing is merged the result has no
     * active elements.
     * @param vectors The vectors to merge.
     * @param dimension The dimension of the returned sparse vector.
     * @param indicesBuffer Scratch space for the merged indices; must be at least as long as the total
     *                      number of active elements across {@code vectors}. Its contents are overwritten.
     * @param valuesBuffer Scratch space for the merged values; same length requirement as
     *                     {@code indicesBuffer}. Its contents are overwritten.
     * @return The merged sparse vector.
     */
    public static SparseVector merge(List<SparseVector> vectors, int dimension, int[] indicesBuffer, double[] valuesBuffer) {
        PriorityQueue<VectorIterator> queue = new PriorityQueue<>();
        for (SparseVector vector : vectors) {
            VectorIterator it = vector.iterator();
            // Empty vectors contribute nothing, and advancing their iterator would throw.
            if (it.hasNext()) {
                it.next();
                queue.add(it);
            }
        }

        // count is the number of distinct indices written so far. Each new slot is assigned
        // (not accumulated) on first use, so the buffers need no pre-clearing.
        int count = 0;
        int lastIndex = -1;
        while (!queue.isEmpty()) {
            // Remove the iterator positioned at the smallest index before advancing it, so the
            // heap never holds an element whose ordering key has been mutated.
            VectorIterator cur = queue.poll();
            VectorTuple ref = cur.getReference();
            if (count > 0 && ref.index == lastIndex) {
                valuesBuffer[count - 1] += ref.value;
            } else {
                lastIndex = ref.index;
                indicesBuffer[count] = lastIndex;
                valuesBuffer[count] = ref.value;
                count++;
            }
            if (cur.hasNext()) {
                cur.next();
                queue.offer(cur);
            }
        }

        int[] indices = Arrays.copyOf(indicesBuffer, count);
        double[] values = Arrays.copyOf(valuesBuffer, count);
        return SparseVector.createSparseVector(dimension, indices, values);
    }
}

