package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Path;
import java.util.logging.Logger;
import org.tribuo.classification.sequence.LabelSequenceEvaluation;
import org.tribuo.classification.sequence.LabelSequenceEvaluator;
import org.tribuo.classification.sequence.example.SequenceDataGenerator;
import org.tribuo.hash.HashCodeHasher;
import org.tribuo.hash.HashingOptions;
import org.tribuo.hash.MessageDigestHasher;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.sequence.HashingSequenceTrainer;
import org.tribuo.sequence.ImmutableSequenceDataset;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceTrainer;

/* loaded from: input_file:org/tribuo/classification/sgd/crf/SeqTest.class */
public class SeqTest {
    private static final Logger logger = Logger.getLogger(SeqTest.class.getName());

    /* renamed from: org.tribuo.classification.sgd.crf.SeqTest$1, reason: invalid class name */
    /* loaded from: input_file:org/tribuo/classification/sgd/crf/SeqTest$1.class */
    /**
     * Holder for a lookup table keyed by the ordinal of a
     * {@link HashingOptions.ModelHashingType} constant.
     *
     * <p>The table assigns the labels 1 to 4 to {@code NONE}, {@code HC}, {@code SHA1} and
     * {@code SHA256} respectively, so that a {@code switch} can select on
     * {@code table[value.ordinal()]}. Both the class and the field are marked synthetic,
     * so this looks like the compiler-generated form of a switch over the enum rather
     * than hand-written code.
     */
    static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant; slots that are not assigned below keep the int default of 0.
        static final /* synthetic */ int[] $SwitchMap$org$tribuo$hash$HashingOptions$ModelHashingType = new int[HashingOptions.ModelHashingType.values().length];

        static {
            // Every constant is registered inside its own try block: if resolving one constant
            // raises NoSuchFieldError, that slot is simply left at 0 and the remaining
            // constants are still registered.
            try {
                $SwitchMap$org$tribuo$hash$HashingOptions$ModelHashingType[HashingOptions.ModelHashingType.NONE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored: the NONE slot stays 0.
            }
            try {
                $SwitchMap$org$tribuo$hash$HashingOptions$ModelHashingType[HashingOptions.ModelHashingType.HC.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored: the HC slot stays 0.
            }
            try {
                $SwitchMap$org$tribuo$hash$HashingOptions$ModelHashingType[HashingOptions.ModelHashingType.SHA1.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Deliberately ignored: the SHA1 slot stays 0.
            }
            try {
                $SwitchMap$org$tribuo$hash$HashingOptions$ModelHashingType[HashingOptions.ModelHashingType.SHA256.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // Deliberately ignored: the SHA256 slot stays 0.
            }
        }
    }

    /* loaded from: input_file:org/tribuo/classification/sgd/crf/SeqTest$CRFOptions.class */
    /**
     * Command line options for the CRF test harness.
     *
     * <p>Each annotated field corresponds to one command line flag; the {@code usage} text on
     * the annotation is the description shown to the user.
     */
    public static class CRFOptions implements Options {
        /** Gradient optimiser options; {@code main} asks this object for the optimiser to train with. */
        public GradientOptimiserOptions gradientOptions;

        /** Where to write the serialized model; unset unless the flag is supplied. */
        @Option(charName = 'f', longName = "output-path", usage = "Path to serialize model to.")
        public Path outputPath;

        /** Name of a built-in example dataset; empty by default. */
        @Option(charName = 'd', longName = "dataset-name", usage = "Name of the example dataset, options are {gorilla}.")
        public String datasetName = "";

        /** Number of SGD epochs; defaults to 5. */
        @Option(charName = 'i', longName = "epochs", usage = "Number of SGD epochs.")
        public int epochs = 5;

        /** Whether to print the feature map, label map and weights after training; off by default. */
        @Option(charName = 'o', longName = "print-model", usage = "Print out feature, label and other model details.")
        public boolean logModel;

        /** Number of examples between objective log messages; defaults to 100. */
        @Option(charName = 'p', longName = "logging-interval", usage = "Log the objective after <int> examples.")
        public int loggingInterval = 100;

        /** Seed for the random number generator; defaults to 1. */
        @Option(charName = 'r', longName = "seed", usage = "RNG seed.")
        public long seed = 1L;

        /** Whether the data is shuffled each epoch; on by default. */
        @Option(longName = "shuffle", usage = "Shuffle the data each epoch (default: true).")
        public boolean shuffle = true;

        /** Path of the serialised training dataset; unset unless the flag is supplied. */
        @Option(charName = 'u', longName = "train-dataset", usage = "Path to a serialised SequenceDataset used for training.")
        public Path trainDataset;

        /** Path of the serialised testing dataset; unset unless the flag is supplied. */
        @Option(charName = 'v', longName = "test-dataset", usage = "Path to a serialised SequenceDataset used for testing.")
        public Path testDataset;

        /** Hashing scheme applied to the model during training; defaults to no hashing. */
        @Option(longName = "model-hashing-algorithm", usage = "Hash the model during training. Defaults to no hashing.")
        public HashingOptions.ModelHashingType modelHashingAlgorithm = HashingOptions.ModelHashingType.NONE;

        /** Salt passed to the model hasher; empty by default. */
        @Option(longName = "model-hashing-salt", usage = "Salt for hashing the model.")
        public String modelHashingSalt = "";

        /**
         * Describes what this program does.
         *
         * @return a one-line description of the program.
         */
        public String getOptionsDescription() {
            return "Tests a linear chain CRF model on the specified dataset.";
        }
    }

    /**
     * Command line entry point: trains a linear chain CRF on a sequence dataset, evaluates it
     * on a test dataset, prints the evaluation and optionally serializes the trained model.
     *
     * <p>The data is either the generated gorilla example (dataset name {@code gorilla} or
     * {@code Gorilla}) or a pair of serialised {@link SequenceDataset}s read from the
     * train/test paths. If neither is available the usage message is logged and the method
     * returns without training.
     *
     * @param strArr the command line arguments.
     * @throws ClassNotFoundException if a serialised dataset refers to a class that cannot be loaded.
     * @throws IOException if a dataset cannot be read or the model cannot be written.
     */
    // Raw sequence types are kept on purpose: the element type is not imported by this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static void main(String[] strArr) throws ClassNotFoundException, IOException {
        LabsLogFormatter.setAllLogFormatters();
        CRFOptions options = new CRFOptions();
        try {
            ConfigurationManager configurationManager = new ConfigurationManager(strArr, options);

            logger.info("Configuring gradient optimiser");
            StochasticGradientOptimiser optimiser = options.gradientOptions.getOptimiser();
            logger.info(String.format("Set logging interval to %d", options.loggingInterval));

            SequenceDataset trainData;
            SequenceDataset testData;
            switch (options.datasetName) {
                case "Gorilla":
                case "gorilla":
                    logger.info("Generating gorilla dataset");
                    trainData = SequenceDataGenerator.generateGorillaDataset(1);
                    testData = SequenceDataGenerator.generateGorillaDataset(1);
                    break;
                default:
                    if (options.trainDataset == null || options.testDataset == null) {
                        logger.warning("Unknown dataset " + options.datasetName);
                        logger.info(configurationManager.usage());
                        return;
                    }
                    logger.info("Loading training data from " + options.trainDataset);
                    // NOTE(review): the datasets are read with Java serialization, so only files
                    // from a trusted source should be passed on the command line.
                    // Each stream is its own resource so the file handle is released even when
                    // constructing the ObjectInputStream (which reads the stream header) fails.
                    try (BufferedInputStream trainBuffer = new BufferedInputStream(new FileInputStream(options.trainDataset.toFile()));
                         ObjectInputStream trainStream = new ObjectInputStream(trainBuffer);
                         BufferedInputStream testBuffer = new BufferedInputStream(new FileInputStream(options.testDataset.toFile()));
                         ObjectInputStream testStream = new ObjectInputStream(testBuffer)) {
                        trainData = (SequenceDataset) trainStream.readObject();
                        logger.info(String.format("Loaded %d training examples for %s", trainData.size(), trainData.getOutputs().toString()));
                        logger.info("Found " + trainData.getFeatureIDMap().size() + " features");
                        logger.info("Loading testing data from " + options.testDataset);
                        // The test data is re-indexed against the training feature and output domains.
                        testData = ImmutableSequenceDataset.copyDataset((SequenceDataset) testStream.readObject(), trainData.getFeatureIDMap(), trainData.getOutputIDInfo());
                        logger.info(String.format("Loaded %d testing examples", testData.size()));
                    }
                    break;
            }

            CRFTrainer crfTrainer = new CRFTrainer(optimiser, options.epochs, options.loggingInterval, options.seed);
            crfTrainer.setShuffle(options.shuffle);

            // Optionally wrap the CRF trainer so the model is hashed during training.
            SequenceTrainer trainer = crfTrainer;
            switch (options.modelHashingAlgorithm) {
                case NONE:
                    break;
                case HC:
                    trainer = new HashingSequenceTrainer(trainer, new HashCodeHasher(options.modelHashingSalt));
                    break;
                case SHA1:
                    trainer = new HashingSequenceTrainer(trainer, new MessageDigestHasher("SHA1", options.modelHashingSalt));
                    break;
                case SHA256:
                    trainer = new HashingSequenceTrainer(trainer, new MessageDigestHasher("SHA-256", options.modelHashingSalt));
                    break;
                default:
                    // Any other hashing type is reported and training proceeds unhashed.
                    logger.info("Unknown hasher " + options.modelHashingAlgorithm);
                    break;
            }

            logger.info("Training using " + trainer.toString());
            CRFModel model = (CRFModel) trainer.train(trainData);
            logger.info("Finished training");

            if (options.logModel) {
                System.out.println("FeatureMap = " + model.getFeatureIDMap().toString());
                System.out.println("LabelMap = " + model.getOutputIDInfo().toString());
                System.out.println("Features - " + model.generateWeightsString());
            }

            LabelSequenceEvaluation evaluation = new LabelSequenceEvaluator().evaluate(model, testData);
            logger.info("Finished evaluating model");
            System.out.println(evaluation.toString());
            System.out.println();
            System.out.println(evaluation.getConfusionMatrix().toString());

            if (options.outputPath != null) {
                // try-with-resources closes both streams even if serialization fails part way.
                try (FileOutputStream fileStream = new FileOutputStream(options.outputPath.toFile());
                     ObjectOutputStream modelStream = new ObjectOutputStream(fileStream)) {
                    modelStream.writeObject(model);
                }
                logger.info("Serialized model to file: " + options.outputPath);
            }
        } catch (UsageException e) {
            // Invalid or help-requesting arguments: report the message and exit normally.
            logger.info(e.getMessage());
        }
    }
}
