/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.nn.test;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Script;

/**
 * Java wrapper around the DML script {@code scripts/nn/test/util.dml}.
 *
 * <p>The constructor loads the script text from the classpath and installs it
 * as this {@link Script}'s script string. For each DML function in that script
 * there are three methods:
 * <ul>
 *   <li>{@code name(...)} builds a fresh {@link Script} that sources the DML
 *       file, calls the function, and returns its single output;</li>
 *   <li>{@code name__docs()} returns the function signature plus its leading
 *       documentation comment;</li>
 *   <li>{@code name__source()} returns the full DML source of the function.</li>
 * </ul>
 */
public class Util
extends Script {
    /**
     * Reads {@code /scripts/nn/test/util.dml} from the classpath and stores its
     * contents as this script's script string.
     *
     * <p>If reading fails with an {@link IOException}, the stack trace is
     * printed and whatever text was read up to that point is still installed.
     *
     * @throws NullPointerException if the resource is not on the classpath
     */
    public Util() {
        String scriptPath = "scripts/nn/test/util.dml";
        InputStream resource = Script.class.getResourceAsStream("/" + scriptPath);
        if (resource == null) {
            // Same exception type the unchecked reader construction used to raise,
            // but with a message that names the missing resource.
            throw new NullPointerException("DML script resource not found on classpath: /" + scriptPath);
        }
        StringBuilder contents = new StringBuilder();
        // try-with-resources closes the reader (and the wrapped stream) on every path.
        // The charset is fixed so decoding does not vary with the platform default.
        try (InputStreamReader reader = new InputStreamReader(resource, StandardCharsets.UTF_8)) {
            char[] buffer = new char[1024];
            int count;
            while ((count = reader.read(buffer)) > 0) {
                contents.append(buffer, 0, count);
            }
        } catch (IOException e) {
            // Best-effort load: report the failure and keep the text read so far.
            e.printStackTrace();
        }
        this.setScriptString(contents.toString());
    }

    /**
     * Invokes the DML function {@code check_all_equal} from
     * {@code scripts/nn/test/util.dml}.
     *
     * @param object value bound to the DML input {@code X1}
     * @param object2 value bound to the DML input {@code X2}
     * @return the boolean DML output {@code equivalent}
     */
    public boolean check_all_equal(Object object, Object object2) {
        String dml = "source('scripts/nn/test/util.dml') as mlcontextns;equivalent = mlcontextns::check_all_equal(X1, X2);";
        Script script = new Script(dml);
        script.in("X1", object).in("X2", object2).out("equivalent");
        MLResults results = script.execute();
        return results.getBoolean("equivalent");
    }

    /**
     * Returns the signature and documentation comment of the DML function
     * {@code check_all_equal}.
     *
     * @return the DML signature and doc comment text
     */
    public String check_all_equal__docs() {
        return "check_all_equal = function(matrix[double] X1, matrix[double] X2)\n    return(boolean equivalent) {\n  /*\n   * Check if two matrices are equivalent, and report any issues.\n   *\n   * Issues an \"ERROR\" statement if elements of the two matrices are\n   * not equal.\n   *\n   * Inputs:\n   *  - X1: Inputs, of shape (any, any).\n   *  - X2: Inputs, of same shape as X1.\n   *\n   * Outputs:\n   *  - equivalent: Whether or not the two matrices are equivalent.\n   */\n";
    }

    /**
     * Returns the complete DML source of the function {@code check_all_equal}.
     *
     * @return the DML source text
     */
    public String check_all_equal__source() {
        return "check_all_equal = function(matrix[double] X1, matrix[double] X2)\n    return(boolean equivalent) {\n  /*\n   * Check if two matrices are equivalent, and report any issues.\n   *\n   * Issues an \"ERROR\" statement if elements of the two matrices are\n   * not equal.\n   *\n   * Inputs:\n   *  - X1: Inputs, of shape (any, any).\n   *  - X2: Inputs, of same shape as X1.\n   *\n   * Outputs:\n   *  - equivalent: Whether or not the two matrices are equivalent.\n   */\n  # Determine if matrices are equivalent\n  equivalent = all_equal(X1, X2)\n\n  # Evaluate relative error\n  if (!equivalent) {\n    print(\"ERROR: The two matrices are not equivalent.\")\n  }\n}\n";
    }

    /**
     * Invokes the DML function {@code all_equal} from
     * {@code scripts/nn/test/util.dml}.
     *
     * @param object value bound to the DML input {@code X1}
     * @param object2 value bound to the DML input {@code X2}
     * @return the boolean DML output {@code equivalent}
     */
    public boolean all_equal(Object object, Object object2) {
        String dml = "source('scripts/nn/test/util.dml') as mlcontextns;equivalent = mlcontextns::all_equal(X1, X2);";
        Script script = new Script(dml);
        script.in("X1", object).in("X2", object2).out("equivalent");
        MLResults results = script.execute();
        return results.getBoolean("equivalent");
    }

    /**
     * Returns the signature and documentation comment of the DML function
     * {@code all_equal}.
     *
     * @return the DML signature and doc comment text
     */
    public String all_equal__docs() {
        return "all_equal = function(matrix[double] X1, matrix[double] X2)\n    return(boolean equivalent) {\n  /*\n   * Determine if two matrices are equivalent.\n   *\n   * Inputs:\n   *  - X1: Inputs, of shape (any, any).\n   *  - X2: Inputs, of same shape as X1.\n   *\n   * Outputs:\n   *  - equivalent: Whether or not the two matrices are equivalent.\n   */\n";
    }

    /**
     * Returns the complete DML source of the function {@code all_equal}.
     *
     * @return the DML source text
     */
    public String all_equal__source() {
        return "all_equal = function(matrix[double] X1, matrix[double] X2)\n    return(boolean equivalent) {\n  /*\n   * Determine if two matrices are equivalent.\n   *\n   * Inputs:\n   *  - X1: Inputs, of shape (any, any).\n   *  - X2: Inputs, of same shape as X1.\n   *\n   * Outputs:\n   *  - equivalent: Whether or not the two matrices are equivalent.\n   */\n  equivalent = as.logical(prod(X1 == X2))\n}\n";
    }

    /**
     * Invokes the DML function {@code compute_rel_error} from
     * {@code scripts/nn/test/util.dml}.
     *
     * @param object value bound to the DML input {@code x1}
     * @param object2 value bound to the DML input {@code x2}
     * @return the double DML output {@code rel_error}
     */
    public double compute_rel_error(Object object, Object object2) {
        String dml = "source('scripts/nn/test/util.dml') as mlcontextns;rel_error = mlcontextns::compute_rel_error(x1, x2);";
        Script script = new Script(dml);
        script.in("x1", object).in("x2", object2).out("rel_error");
        MLResults results = script.execute();
        return results.getDouble("rel_error");
    }

    /**
     * Returns the signature and documentation comment of the DML function
     * {@code compute_rel_error}.
     *
     * @return the DML signature and doc comment text
     */
    public String compute_rel_error__docs() {
        return "compute_rel_error = function(double x1, double x2)\n    return (double rel_error) {\n  /*\n   * Relative error measure between two values.\n   *\n   * Uses smoothing to avoid divide-by-zero errors.\n   *\n   * Inputs:\n   *  - x1: First value.\n   *  - x2: Second value.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two values.\n   */\n";
    }

    /**
     * Returns the complete DML source of the function {@code compute_rel_error}.
     *
     * @return the DML source text
     */
    public String compute_rel_error__source() {
        return "compute_rel_error = function(double x1, double x2)\n    return (double rel_error) {\n  /*\n   * Relative error measure between two values.\n   *\n   * Uses smoothing to avoid divide-by-zero errors.\n   *\n   * Inputs:\n   *  - x1: First value.\n   *  - x2: Second value.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two values.\n   */\n  rel_error = abs(x1-x2) / max(1e-8, abs(x1)+abs(x2))\n}\n";
    }

    /**
     * Invokes the DML function {@code check_rel_error} from
     * {@code scripts/nn/test/util.dml}.
     *
     * @param object value bound to the DML input {@code x1}
     * @param object2 value bound to the DML input {@code x2}
     * @param object3 value bound to the DML input {@code thresh_error}
     * @param object4 value bound to the DML input {@code thresh_warn}
     * @return the double DML output {@code rel_error}
     */
    public double check_rel_error(Object object, Object object2, Object object3, Object object4) {
        String dml = "source('scripts/nn/test/util.dml') as mlcontextns;rel_error = mlcontextns::check_rel_error(x1, x2, thresh_error, thresh_warn);";
        Script script = new Script(dml);
        script.in("x1", object).in("x2", object2).in("thresh_error", object3).in("thresh_warn", object4).out("rel_error");
        MLResults results = script.execute();
        return results.getDouble("rel_error");
    }

    /**
     * Returns the signature and documentation comment of the DML function
     * {@code check_rel_error}.
     *
     * @return the DML signature and doc comment text
     */
    public String check_rel_error__docs() {
        return "check_rel_error = function(double x1, double x2, double thresh_error, double thresh_warn)\n    return (double rel_error) {\n  /*\n   * Check and report any issues with the relative error measure between\n   * two values.\n   *\n   * Issues an \"ERROR\" statement for relative errors > thresh_error,\n   * indicating that the implementation is likely incorrect.\n   *\n   * Issues a \"WARNING\" statement for relative errors < thresh_error\n   * but > thresh_warn, indicating that the implementation may be\n   * incorrect.\n   *\n   * Inputs:\n   *  - x1: First value.\n   *  - x2: Second value.\n   *  - thresh_error: Error threshold.\n   *  - thresh_warn: Warning threshold.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two values.\n   */\n";
    }

    /**
     * Returns the complete DML source of the function {@code check_rel_error}.
     *
     * @return the DML source text
     */
    public String check_rel_error__source() {
        return "check_rel_error = function(double x1, double x2, double thresh_error, double thresh_warn)\n    return (double rel_error) {\n  /*\n   * Check and report any issues with the relative error measure between\n   * two values.\n   *\n   * Issues an \"ERROR\" statement for relative errors > thresh_error,\n   * indicating that the implementation is likely incorrect.\n   *\n   * Issues a \"WARNING\" statement for relative errors < thresh_error\n   * but > thresh_warn, indicating that the implementation may be\n   * incorrect.\n   *\n   * Inputs:\n   *  - x1: First value.\n   *  - x2: Second value.\n   *  - thresh_error: Error threshold.\n   *  - thresh_warn: Warning threshold.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two values.\n   */\n  # Compute relative error\n  rel_error = compute_rel_error(x1, x2)\n\n  # Evaluate relative error\n  if (rel_error > thresh_error) {\n    print(\"ERROR: Relative error \" + rel_error + \" > \" + thresh_error + \" with \" + x1 +\n          \" vs \" + x2 + \".\")\n  }\n  else if (rel_error > thresh_warn & rel_error <= thresh_error) {\n    print(\"WARNING: Relative error \" + rel_error + \" > \" + thresh_warn + \" & <= \" + thresh_error +\n          \" with \" + x1 + \" vs \" + x2 + \".\")\n  }\n}\n";
    }

    /**
     * Invokes the DML function {@code check_rel_grad_error} from
     * {@code scripts/nn/test/util.dml}.
     *
     * @param object value bound to the DML input {@code dw_a}
     * @param object2 value bound to the DML input {@code dw_n}
     * @param object3 value bound to the DML input {@code lossph}
     * @param object4 value bound to the DML input {@code lossmh}
     * @return the double DML output {@code rel_error}
     */
    public double check_rel_grad_error(Object object, Object object2, Object object3, Object object4) {
        String dml = "source('scripts/nn/test/util.dml') as mlcontextns;rel_error = mlcontextns::check_rel_grad_error(dw_a, dw_n, lossph, lossmh);";
        Script script = new Script(dml);
        script.in("dw_a", object).in("dw_n", object2).in("lossph", object3).in("lossmh", object4).out("rel_error");
        MLResults results = script.execute();
        return results.getDouble("rel_error");
    }

    /**
     * Returns the signature and documentation comment of the DML function
     * {@code check_rel_grad_error}.
     *
     * @return the DML signature and doc comment text
     */
    public String check_rel_grad_error__docs() {
        return "check_rel_grad_error = function(double dw_a, double dw_n, double lossph, double lossmh)\n    return (double rel_error) {\n  /*\n   * Check and report any issues with the relative error measure between\n   * the analytical and numerical partial derivatives.\n   *\n   *  - Issues an \"ERROR\" statement for relative errors > 1e-2,\n   *  indicating that the gradient is likely incorrect.\n   *  - Issues a \"WARNING\" statement for relative errors < 1e-2\n   *  but > 1e-4, indicating that the may be incorrect.\n   *\n   * Inputs:\n   *  - dw_a: Analytical partial derivative wrt w.\n   *  - dw_n: Numerical partial derivative wrt w.\n   *  - lossph: Loss evaluated with w set to w+h.\n   *  - lossmh: Loss evaluated with w set to w-h.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two derivatives.\n   */\n";
    }

    /**
     * Returns the complete DML source of the function
     * {@code check_rel_grad_error}.
     *
     * @return the DML source text
     */
    public String check_rel_grad_error__source() {
        return "check_rel_grad_error = function(double dw_a, double dw_n, double lossph, double lossmh)\n    return (double rel_error) {\n  /*\n   * Check and report any issues with the relative error measure between\n   * the analytical and numerical partial derivatives.\n   *\n   *  - Issues an \"ERROR\" statement for relative errors > 1e-2,\n   *  indicating that the gradient is likely incorrect.\n   *  - Issues a \"WARNING\" statement for relative errors < 1e-2\n   *  but > 1e-4, indicating that the may be incorrect.\n   *\n   * Inputs:\n   *  - dw_a: Analytical partial derivative wrt w.\n   *  - dw_n: Numerical partial derivative wrt w.\n   *  - lossph: Loss evaluated with w set to w+h.\n   *  - lossmh: Loss evaluated with w set to w-h.\n   *\n   * Outputs:\n   *  - rel_error: Relative error measure between the two derivatives.\n   */\n  # Compute relative error\n  rel_error = compute_rel_error(dw_a, dw_n)\n\n  # Evaluate relative error\n  thresh_error = 1e-2\n  thresh_warn = 1e-4\n  if (rel_error > thresh_error) {\n    print(\"ERROR: Relative error \" + rel_error + \" > \" + thresh_error + \" with \" + dw_a +\n          \" analytical vs \" + dw_n + \" numerical, with lossph \" + lossph +\n          \" and lossmh \" + lossmh)\n  }\n  else if (rel_error > thresh_warn & rel_error <= thresh_error) {\n    print(\"WARNING: Relative error \" + rel_error + \" > \" + thresh_warn + \" & <= \" + thresh_error +\n          \" with \" + dw_a + \" analytical vs \" + dw_n + \" numerical, with lossph \" + lossph +\n          \" and lossmh \" + lossmh)\n  }\n}\n";
    }
}

