package com.hankcs.hanlp.model.crf.crfpp;

import com.google.android.material.shadow.ShadowDrawableWrapper;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.crf.crfpp.TaggerImpl;
import e.b.a.a.a;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/* loaded from: classes.dex */
public class Encoder {
    public static int MODEL_VERSION = 100;

    /* renamed from: com.hankcs.hanlp.model.crf.crfpp.Encoder$1, reason: invalid class name */
    /* loaded from: classes.dex */
    public static /* synthetic */ class AnonymousClass1 {
        public static final /* synthetic */ int[] $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm;

        static {
            Algorithm.values();
            int[] iArr = new int[3];
            $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm = iArr;
            try {
                Algorithm algorithm = Algorithm.CRF_L1;
                iArr[1] = 1;
            } catch (NoSuchFieldError unused) {
            }
            try {
                int[] iArr2 = $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm;
                Algorithm algorithm2 = Algorithm.CRF_L2;
                iArr2[0] = 2;
            } catch (NoSuchFieldError unused2) {
            }
            try {
                int[] iArr3 = $SwitchMap$com$hankcs$hanlp$model$crf$crfpp$Encoder$Algorithm;
                Algorithm algorithm3 = Algorithm.MIRA;
                iArr3[2] = 3;
            } catch (NoSuchFieldError unused3) {
            }
        }
    }

    /* loaded from: classes.dex */
    public enum Algorithm {
        CRF_L2,
        CRF_L1,
        MIRA;

        public static Algorithm fromString(String str) {
            String lowerCase = str.toLowerCase();
            if (lowerCase.equals("crf") || lowerCase.equals("crf-l2")) {
                return CRF_L2;
            }
            if (lowerCase.equals("crf-l1")) {
                return CRF_L1;
            }
            if (lowerCase.equals("mira")) {
                return MIRA;
            }
            throw new IllegalArgumentException(a.t("invalid algorithm: ", lowerCase));
        }
    }

    public static void main(String[] strArr) {
        if (strArr.length < 3) {
            System.err.println("incorrect No. of args");
            return;
        }
        String str = strArr[0];
        String str2 = strArr[1];
        String str3 = strArr[2];
        Encoder encoder = new Encoder();
        long time = new Date().getTime();
        if (encoder.learn(str, str2, str3, false, 100000, 1, 1.0E-4d, 1.0d, 1, 20, Algorithm.CRF_L2)) {
            System.out.println(new Date().getTime() - time);
        } else {
            System.err.println("error training model");
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    private boolean runCRF(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i2, double d2, double d3, int i3, int i4, boolean z) {
        int size;
        LbfgsOptimizer lbfgsOptimizer = new LbfgsOptimizer();
        ArrayList arrayList = new ArrayList();
        boolean z2 = 0;
        for (int i5 = 0; i5 < i4; i5++) {
            CRFEncoderThread cRFEncoderThread = new CRFEncoderThread(dArr.length);
            cRFEncoderThread.start_i = i5;
            cRFEncoderThread.size = list.size();
            cRFEncoderThread.threadNum = i4;
            cRFEncoderThread.x = list;
            arrayList.add(cRFEncoderThread);
        }
        int i6 = 0;
        for (int i7 = 0; i7 < list.size(); i7++) {
            i6 += list.get(i7).size();
        }
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(i4);
        int i8 = 0;
        double d4 = 1.0E37d;
        int i9 = 0;
        while (i8 < i2) {
            encoderFeatureIndex.clear();
            try {
                newFixedThreadPool.invokeAll(arrayList);
                int i10 = 1;
                int i11 = z2;
                while (i10 < i4) {
                    ((CRFEncoderThread) arrayList.get(i11)).obj += ((CRFEncoderThread) arrayList.get(i10)).obj;
                    ((CRFEncoderThread) arrayList.get(0)).err += ((CRFEncoderThread) arrayList.get(i10)).err;
                    ((CRFEncoderThread) arrayList.get(0)).zeroone += ((CRFEncoderThread) arrayList.get(i10)).zeroone;
                    i10++;
                    i11 = 0;
                    i9 = i9;
                    d4 = d4;
                }
                int i12 = i9;
                double d5 = d4;
                for (int i13 = 1; i13 < i4; i13++) {
                    for (int i14 = 0; i14 < encoderFeatureIndex.size(); i14++) {
                        double[] dArr2 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        dArr2[i14] = dArr2[i14] + ((CRFEncoderThread) arrayList.get(i13)).expected[i14];
                    }
                }
                if (z) {
                    size = 0;
                    for (int i15 = 0; i15 < encoderFeatureIndex.size(); i15++) {
                        CRFEncoderThread cRFEncoderThread2 = (CRFEncoderThread) arrayList.get(0);
                        cRFEncoderThread2.obj = Math.abs(dArr[i15] / d2) + cRFEncoderThread2.obj;
                        if (dArr[i15] != ShadowDrawableWrapper.r) {
                            size++;
                        }
                    }
                } else {
                    size = encoderFeatureIndex.size();
                    for (int i16 = 0; i16 < encoderFeatureIndex.size(); i16++) {
                        CRFEncoderThread cRFEncoderThread3 = (CRFEncoderThread) arrayList.get(0);
                        cRFEncoderThread3.obj = ((dArr[i16] * dArr[i16]) / (2.0d * d2)) + cRFEncoderThread3.obj;
                        double[] dArr3 = ((CRFEncoderThread) arrayList.get(0)).expected;
                        dArr3[i16] = (dArr[i16] / d2) + dArr3[i16];
                    }
                }
                for (int i17 = 1; i17 < i4; i17++) {
                    ((CRFEncoderThread) arrayList.get(i17)).expected = null;
                }
                double abs = i8 == 0 ? 1.0d : Math.abs(d5 - ((CRFEncoderThread) arrayList.get(0)).obj) / d5;
                StringBuilder E = a.E("iter=", i8, " terr=");
                E.append((((CRFEncoderThread) arrayList.get(0)).err * 1.0d) / i6);
                E.append(" serr=");
                E.append((((CRFEncoderThread) arrayList.get(0)).zeroone * 1.0d) / list.size());
                E.append(" act=");
                E.append(size);
                E.append(" obj=");
                E.append(((CRFEncoderThread) arrayList.get(0)).obj);
                E.append(" diff=");
                E.append(abs);
                System.out.println(E.toString());
                double d6 = ((CRFEncoderThread) arrayList.get(0)).obj;
                int i18 = abs < d3 ? i12 + 1 : 0;
                if (i8 > i2 || i18 == 3) {
                    break;
                }
                ExecutorService executorService = newFixedThreadPool;
                int i19 = i8;
                if (lbfgsOptimizer.optimize(encoderFeatureIndex.size(), dArr, ((CRFEncoderThread) arrayList.get(0)).obj, ((CRFEncoderThread) arrayList.get(0)).expected, z, d2) <= 0) {
                    return false;
                }
                z2 = 0;
                i8 = i19 + 1;
                newFixedThreadPool = executorService;
                i9 = i18;
                d4 = d6;
            } catch (Exception e2) {
                e2.printStackTrace();
                return z2;
            }
        }
        ExecutorService executorService2 = newFixedThreadPool;
        executorService2.shutdown();
        try {
            executorService2.awaitTermination(-1L, TimeUnit.SECONDS);
            return true;
        } catch (Exception e3) {
            e3.printStackTrace();
            System.err.println("fail waiting executor to shutdown");
            return true;
        }
    }

    public boolean learn(String str, String str2, String str3, boolean z, int i2, int i3, double d2, double d3, int i4, int i5, Algorithm algorithm) {
        if (d2 <= ShadowDrawableWrapper.r) {
            System.err.println("eta must be > 0.0");
            return false;
        }
        if (d3 < ShadowDrawableWrapper.r) {
            System.err.println("C must be >= 0.0");
            return false;
        }
        if (i5 < 1) {
            System.err.println("shrinkingSize must be >= 1");
            return false;
        }
        if (i4 <= 0) {
            System.err.println("thread must be  > 0");
            return false;
        }
        EncoderFeatureIndex encoderFeatureIndex = new EncoderFeatureIndex(i4);
        ArrayList arrayList = new ArrayList();
        if (!encoderFeatureIndex.open(str, str2)) {
            System.err.println("Fail to open " + str + " " + str2);
        }
        try {
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(str2), "UTF-8"));
            int i6 = 0;
            while (true) {
                TaggerImpl taggerImpl = new TaggerImpl(TaggerImpl.Mode.LEARN);
                taggerImpl.open(encoderFeatureIndex);
                TaggerImpl.ReadStatus read = taggerImpl.read(bufferedReader);
                if (read == TaggerImpl.ReadStatus.ERROR) {
                    System.err.println("error when reading " + str2);
                    return false;
                }
                if (taggerImpl.empty()) {
                    if (read == TaggerImpl.ReadStatus.EOF) {
                        bufferedReader.close();
                        encoderFeatureIndex.shrink(i3, arrayList);
                        double[] dArr = new double[encoderFeatureIndex.size()];
                        Arrays.fill(dArr, ShadowDrawableWrapper.r);
                        encoderFeatureIndex.setAlpha_(dArr);
                        PrintStream printStream = System.out;
                        StringBuilder D = a.D("Number of sentences: ");
                        D.append(arrayList.size());
                        printStream.println(D.toString());
                        PrintStream printStream2 = System.out;
                        StringBuilder D2 = a.D("Number of features:  ");
                        D2.append(encoderFeatureIndex.size());
                        printStream2.println(D2.toString());
                        System.out.println("Number of thread(s): " + i4);
                        System.out.println("Freq:                " + i3);
                        System.out.println("eta:                 " + d2);
                        System.out.println("C:                   " + d3);
                        System.out.println("shrinking size:      " + i5);
                        int ordinal = algorithm.ordinal();
                        if (ordinal != 0) {
                            if (ordinal != 1) {
                                if (ordinal == 2 && !runMIRA(arrayList, encoderFeatureIndex, dArr, i2, d3, d2, i5, i4)) {
                                    System.err.println("MIRA execute error");
                                    return false;
                                }
                            } else if (!runCRF(arrayList, encoderFeatureIndex, dArr, i2, d3, d2, i5, i4, true)) {
                                System.err.println("CRF_L1 execute error");
                                return false;
                            }
                        } else if (!runCRF(arrayList, encoderFeatureIndex, dArr, i2, d3, d2, i5, i4, false)) {
                            System.err.println("CRF_L2 execute error");
                            return false;
                        }
                        if (!encoderFeatureIndex.save(str3, z)) {
                            System.err.println("Failed to save model");
                        }
                        System.out.println("Done!");
                        return true;
                    }
                } else {
                    if (!taggerImpl.shrink()) {
                        System.err.println("fail to build feature index ");
                        return false;
                    }
                    taggerImpl.setThread_id_(i6 % i4);
                    arrayList.add(taggerImpl);
                    i6++;
                    if (i6 % 100 == 0) {
                        System.out.print(i6 + ".. ");
                    }
                }
            }
        } catch (IOException unused) {
            System.err.println("train file " + str2 + " does not exist.");
            return false;
        }
    }

    public boolean runMIRA(List<TaggerImpl> list, EncoderFeatureIndex encoderFeatureIndex, double[] dArr, int i2, double d2, double d3, int i3, int i4) {
        Integer num;
        int i5;
        int i6;
        Double d4;
        List<TaggerImpl> list2 = list;
        int i7 = i2;
        double d5 = d2;
        Integer[] numArr = new Integer[list.size()];
        int i8 = 0;
        Integer num2 = 0;
        Arrays.fill(numArr, num2);
        List asList = Arrays.asList(numArr);
        Double[] dArr2 = new Double[list.size()];
        Double valueOf = Double.valueOf(ShadowDrawableWrapper.r);
        Arrays.fill(dArr2, valueOf);
        List asList2 = Arrays.asList(dArr2);
        List<Double> asList3 = Arrays.asList(new Double[encoderFeatureIndex.size()]);
        if (i4 > 1) {
            System.err.println("WARN: MIRA does not support multi-threading");
        }
        int i9 = 0;
        for (int i10 = 0; i10 < list.size(); i10++) {
            i9 += list2.get(i10).size();
        }
        boolean z = true;
        double d6 = 0.0d;
        int i11 = 0;
        int i12 = 0;
        while (i8 < i7) {
            int i13 = i11;
            int i14 = i13;
            int i15 = i12;
            double d7 = d6;
            int i16 = i14;
            int i17 = i16;
            while (i11 < list.size()) {
                int i18 = i9;
                if (((Integer) asList.get(i11)).intValue() < i3) {
                    int i19 = i16 + 1;
                    for (int i20 = 0; i20 < asList3.size(); i20++) {
                        asList3.set(i20, valueOf);
                    }
                    double collins = list2.get(i11).collins(asList3);
                    int eval = list2.get(i11).eval();
                    i13 += eval;
                    if (eval != 0) {
                        i14++;
                    }
                    if (eval == 0) {
                        asList.set(i11, Integer.valueOf(((Integer) asList.get(i11)).intValue() + 1));
                        i16 = i19;
                    } else {
                        asList.set(i11, num2);
                        double d8 = ShadowDrawableWrapper.r;
                        for (int i21 = 0; i21 < asList3.size(); i21++) {
                            d8 = (asList3.get(i21).doubleValue() * asList3.get(i21).doubleValue()) + d8;
                        }
                        double d9 = eval - collins;
                        d4 = valueOf;
                        int i22 = i14;
                        double max = Math.max(ShadowDrawableWrapper.r, d9 / d8);
                        if (((Double) asList2.get(i11)).doubleValue() + max > d5) {
                            max = d5 - ((Double) asList2.get(i11)).doubleValue();
                            i17++;
                        } else {
                            d7 = Math.max(d9, d7);
                        }
                        if (max > 1.0E-10d) {
                            asList2.set(i11, Double.valueOf(((Double) asList2.get(i11)).doubleValue() + max));
                            asList2.set(i11, Double.valueOf(Math.min(d5, ((Double) asList2.get(i11)).doubleValue())));
                            for (int i23 = 0; i23 < asList3.size(); i23++) {
                                dArr[i23] = (asList3.get(i23).doubleValue() * max) + dArr[i23];
                            }
                        }
                        i13 = i13;
                        i14 = i22;
                        i16 = i19;
                        i11++;
                        list2 = list;
                        i9 = i18;
                        valueOf = d4;
                    }
                }
                d4 = valueOf;
                i11++;
                list2 = list;
                i9 = i18;
                valueOf = d4;
            }
            Double d10 = valueOf;
            int i24 = i9;
            double d11 = d7;
            double d12 = 0.0d;
            for (int i25 = 0; i25 < encoderFeatureIndex.size(); i25++) {
                d12 = (dArr[i25] * dArr[i25]) + d12;
            }
            StringBuilder E = a.E("iter=", i8, " terr=");
            List<Double> list3 = asList3;
            int i26 = i8;
            Integer num3 = num2;
            E.append((i13 * 1.0d) / i24);
            E.append(" serr=");
            E.append((i14 * 1.0d) / list.size());
            E.append(" act=");
            E.append(i16);
            E.append(" uact=");
            E.append(i17);
            E.append(" obj=");
            E.append(d12);
            E.append(" kkt=");
            E.append(d11);
            System.out.println(E.toString());
            d6 = ShadowDrawableWrapper.r;
            if (d11 <= ShadowDrawableWrapper.r) {
                for (int i27 = 0; i27 < asList.size(); i27++) {
                    asList.set(i27, num3);
                }
                num = num3;
                i12 = i15 + 1;
                i6 = i2;
                i5 = i26;
            } else {
                num = num3;
                i12 = 0;
                i5 = i26;
                i6 = i2;
            }
            if (i5 > i6 || i12 == 2) {
                return true;
            }
            i8 = i5 + 1;
            i11 = 0;
            z = true;
            num2 = num;
            asList3 = list3;
            valueOf = d10;
            d5 = d2;
            i9 = i24;
            i7 = i6;
            list2 = list;
        }
        return z;
    }
}
