/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types.tests;

import cc.mallet.pipe.Noop;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for the splitting methods of {@link InstanceList}: {@code split} and
 * {@code stratifiedSplitInOrder}.
 */
public class TestInstanceList {
    /** Label sequence shared by the tests: 15 instances, nine labelled "0" and six labelled "1". */
    private static final String LABELS = "001011100100100";

    /**
     * Splits 15 instances into three equal proportions with {@link InstanceList#split(double[])}
     * and asserts that three lists of 15 / 3 = 5 instances each come back.
     */
    @Test
    public void testSplit() {
        LabelAlphabet labelDict = new LabelAlphabet();
        // Register "0" then "1" before any instance is created.
        labelDict.lookupIndex("0", true);
        labelDict.lookupIndex("1", true);
        InstanceList ilist = addInstances(new InstanceList(new Noop(null, labelDict)), labelDict, LABELS);

        int numFolds = 3;
        InstanceList[] instSplits = ilist.split(equalProportions(numFolds));

        Assert.assertEquals(numFolds, instSplits.length);
        for (InstanceList splitList : instSplits) {
            Assert.assertEquals(LABELS.length() / numFolds, splitList.size());
        }
    }

    /**
     * Runs {@link #assertStratifiedSplit} over several label sequences and fold counts. The cases
     * cover fold counts that divide the instance count evenly (15/3, 15/5, 9/3) and one that
     * does not (15/6).
     */
    @Test
    public void testStratifiedSplit() {
        assertStratifiedSplit(LABELS, 3);
        assertStratifiedSplit(LABELS, 5);
        assertStratifiedSplit(LABELS, 6);
        assertStratifiedSplit("000111222", 3);
    }

    /**
     * Builds one instance per character of {@code labelsString} and splits them into
     * {@code numFolds} equal proportions with {@link InstanceList#stratifiedSplitInOrder(double[])}.
     * Then it asserts the fold count and fold sizes.
     *
     * <p>When the instance count is divisible by {@code numFolds}, every fold must hold exactly
     * {@code length / numFolds} instances. Otherwise every fold must hold at least that many.
     *
     * @param labelsString one single-character label per instance
     * @param numFolds     number of equal-proportion folds to request
     */
    private void assertStratifiedSplit(String labelsString, int numFolds) {
        LabelAlphabet labelDict = new LabelAlphabet();
        // Register labels in order of first appearance in the sequence.
        for (int i = 0; i < labelsString.length(); ++i) {
            labelDict.lookupIndex(Character.toString(labelsString.charAt(i)), true);
        }
        InstanceList ilist = addInstances(new InstanceList(null, labelDict), labelDict, labelsString);

        InstanceList[] instSplits = ilist.stratifiedSplitInOrder(equalProportions(numFolds));

        Assert.assertEquals(numFolds, instSplits.length);
        int minFoldSize = labelsString.length() / numFolds;
        boolean evenlyDivisible = labelsString.length() % numFolds == 0;
        for (InstanceList splitList : instSplits) {
            if (evenlyDivisible) {
                Assert.assertEquals(minFoldSize, splitList.size());
            } else {
                Assert.assertTrue(
                        "fold size " + splitList.size() + " is below the minimum " + minFoldSize,
                        splitList.size() >= minFoldSize);
            }
        }
    }

    /**
     * Appends one instance to {@code ilist} for each character of {@code labelsString}. Each
     * instance has an empty {@code double[]} as data and the label named by that character. The
     * label is looked up in {@code labelDict} without adding new entries. The instance is named
     * {@code "i<position>"} and has a {@code null} source.
     *
     * @return {@code ilist}, for call chaining
     */
    private static InstanceList addInstances(InstanceList ilist, LabelAlphabet labelDict, String labelsString) {
        // All instances share one empty data array; the tests only inspect labels and counts.
        double[] data = new double[0];
        for (int i = 0; i < labelsString.length(); ++i) {
            Label instLabel = labelDict.lookupLabel(Character.toString(labelsString.charAt(i)), false);
            ilist.add(new Instance(data, instLabel, "i" + i, null));
        }
        return ilist;
    }

    /** Returns {@code numFolds} proportions, each equal to {@code 1.0 / numFolds}. */
    private static double[] equalProportions(int numFolds) {
        double[] proportions = new double[numFolds];
        Arrays.fill(proportions, 1.0 / numFolds);
        return proportions;
    }
}

