﻿using System;
using System.IO;

namespace LinearRidgeRegressNumeric
{
  /// <summary>
  /// Console demo: loads the synthetic train/test files, fits a
  /// <see cref="LinearRidgeRegress"/> model, prints accuracy and RMSE
  /// for both data sets, and makes one sample prediction.
  /// </summary>
  internal class LinearRidgeRegressNumericProgram
  {
    /// <summary>
    /// Runs the demo end to end and waits for Enter before exiting.
    /// </summary>
    /// <param name="args">Command-line arguments (not used).</param>
    static void Main(string[] args)
    {
      Console.WriteLine("\nLinear ridge regression demo ");

      // 1. load data from file to memory
      // columns 0-2 hold the predictors, column 3 the target;
      // lines starting with "#" are skipped as comments
      Console.WriteLine("\nLoading train and test data ");
      int[] xCols = new int[] { 0, 1, 2 };
      int[] yCols = new int[] { 3 };

      string trainFile =
        "..\\..\\..\\Data\\synthetic_train.txt";
      double[][] trainX =
        Utils.MatLoad(trainFile, xCols, ',', "#");
      double[] trainY =
        Utils.MatToVec(Utils.MatLoad(trainFile, yCols, ',', "#"));

      string testFile =
        "..\\..\\..\\Data\\synthetic_test.txt";
      double[][] testX =
        Utils.MatLoad(testFile, xCols, ',', "#");
      double[] testY =
        Utils.MatToVec(Utils.MatLoad(testFile, yCols, ',', "#"));

      // preview at most three rows so a tiny data set does not
      // trigger an IndexOutOfRangeException
      int nShow = Math.Min(3, Math.Min(trainX.Length, trainY.Length));

      Console.WriteLine("\nFirst three train X inputs: ");
      for (int i = 0; i < nShow; ++i)
        Utils.VecShow(trainX[i], 4, 9, true);

      Console.WriteLine("\nFirst three train y targets: ");
      for (int i = 0; i < nShow; ++i)
        Console.WriteLine(trainY[i].ToString("F4"));

      // 2. create LR model
      double alpha = 0.05;
      Console.WriteLine("\nCreating and training LRR model" +
        " with alpha = " + alpha.ToString("F3"));
      Console.WriteLine("Coefficients = " +
        " inv(DXt * DX) * DXt * Y ");
      LinearRidgeRegress lrrModel =
        new LinearRidgeRegress(alpha);

      // 3. train model
      lrrModel.Train(trainX, trainY);
      Console.WriteLine("Done. Model constant," +
        " coefficients = ");
      Utils.VecShow(lrrModel.coeffs, 4, 9, true);

      // 4. evaluate model
      Console.WriteLine("\nComputing model accuracy ");
      double accTrain =
        Accuracy(lrrModel, trainX, trainY, 0.10);
      Console.WriteLine("\nAccuracy on train (0.10) = "
        + accTrain.ToString("F4"));

      double accTest =
        Accuracy(lrrModel, testX, testY, 0.10);
      Console.WriteLine("Accuracy on test (0.10) = " +
        accTest.ToString("F4"));

      double rmseTrain =
        RootMSE(lrrModel, trainX, trainY);
      Console.WriteLine("\nRMSE on train = "
        + rmseTrain.ToString("F4"));

      double rmseTest =
        RootMSE(lrrModel, testX, testY);
      Console.WriteLine("RMSE on test = " +
        rmseTest.ToString("F4"));

      // 5. use model
      Console.WriteLine("\nPredicting x = 0.5, -0.6, 0.7 ");
      double[] x = new double[] { 0.5, -0.6, 0.7 };
      double predY = lrrModel.Predict(x);
      Console.WriteLine("\nPredicted y = " +
        predY.ToString("F4"));

      // 6. TODO: save model to file

      Console.WriteLine("\nEnd linear ridge regression ");
      Console.ReadLine();
    } // Main

    /// <summary>
    /// Fraction of rows whose prediction is strictly within
    /// <paramref name="pctClose"/> (as a fraction of the actual
    /// target's magnitude) of the actual target.
    /// </summary>
    /// <param name="model">Trained model used for prediction.</param>
    /// <param name="dataX">Predictor rows.</param>
    /// <param name="dataY">Actual target for each row of dataX.</param>
    /// <param name="pctClose">Relative tolerance, e.g. 0.10.</param>
    /// <returns>
    /// Correct count divided by the row count; NaN when dataX has
    /// no rows (0.0 / 0).
    /// </returns>
    /// <remarks>
    /// The comparison is strict, so a row whose actual target is
    /// exactly 0 is never counted as correct.
    /// </remarks>
    static double Accuracy(LinearRidgeRegress model,
      double[][] dataX, double[] dataY, double pctClose)
    {
      int numCorrect = 0;
      int n = dataX.Length;
      for (int i = 0; i < n; ++i)
      {
        double yActual = dataY[i];
        double yPred = model.Predict(dataX[i]);
        if (Math.Abs(yActual - yPred) <
          Math.Abs(pctClose * yActual))
          ++numCorrect;
      }
      return (numCorrect * 1.0) / n;
    }

    /// <summary>
    /// Root mean squared error of the model's predictions.
    /// </summary>
    /// <param name="model">Trained model used for prediction.</param>
    /// <param name="dataX">Predictor rows.</param>
    /// <param name="dataY">Actual target for each row of dataX.</param>
    /// <returns>
    /// sqrt(sum of squared errors / row count); NaN when dataX has
    /// no rows.
    /// </returns>
    static double RootMSE(LinearRidgeRegress model,
      double[][] dataX, double[] dataY)
    {
      double sum = 0.0;
      int n = dataX.Length;
      for (int i = 0; i < n; ++i)
      {
        double err = dataY[i] - model.Predict(dataX[i]);
        sum += err * err;
      }
      return Math.Sqrt(sum / n);
    }

  } // Program

  /// <summary>
  /// Linear regression with a ridge term, trained in closed form:
  /// coeffs = inv(DXt * DX + alpha * I) * DXt * Y, where DX is the
  /// training matrix with a leading column of 1s.
  /// </summary>
  public class LinearRidgeRegress
  {
    /// <summary>
    /// Model weights; null until <see cref="Train"/> has run.
    /// Element [0] is the constant, element [i + 1] is the weight
    /// of predictor i.
    /// </summary>
    public double[]? coeffs;

    /// <summary>
    /// Value added to every diagonal element of DXt * DX before
    /// inversion.
    /// </summary>
    public double alpha;

    /// <summary>Creates an untrained model.</summary>
    /// <param name="alpha">Ridge constant added to the diagonal.</param>
    public LinearRidgeRegress(double alpha)
    {
      this.coeffs = null;
      this.alpha = alpha;
    }

    /// <summary>
    /// Fits the model and stores the result in <see cref="coeffs"/>.
    /// </summary>
    /// <param name="trainX">Predictor rows, one row per item.</param>
    /// <param name="trainY">Target value for each row of trainX.</param>
    /// <exception cref="ArgumentException">
    /// trainX is empty, or trainX and trainY have different lengths.
    /// </exception>
    public void Train(double[][] trainX, double[] trainY)
    {
      if (trainX.Length == 0)
        throw new ArgumentException(
          "Training data must contain at least one row",
          nameof(trainX));
      if (trainX.Length != trainY.Length)
        throw new ArgumentException(
          "trainX and trainY must have the same number of rows",
          nameof(trainY));

      double[][] DX =
        Utils.MatMakeDesign(trainX);  // add 1s column

      // coeffs = inv(DXt * DX + alpha * I) * DXt * Y
      double[][] a = Utils.MatTranspose(DX);
      double[][] b = Utils.MatProduct(a, DX);

      // ridge term: alpha goes on every diagonal element,
      // including [0][0], which belongs to the constant
      for (int i = 0; i < b.Length; ++i)
        b[i][i] += this.alpha;

      double[][] c = Utils.MatInverse(b);
      double[][] d = Utils.MatProduct(c, a);
      double[][] Y = Utils.VecToMat(trainY, DX.Length, 1);
      double[][] e = Utils.MatProduct(d, Y);
      this.coeffs = Utils.MatToVec(e);
    }

    /// <summary>
    /// Computes coeffs[0] + sum of x[i] * coeffs[i + 1].
    /// </summary>
    /// <param name="x">Predictor values for one item.</param>
    /// <returns>The predicted target value.</returns>
    /// <exception cref="InvalidOperationException">
    /// The model has not been trained.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// x does not have one value per trained predictor.
    /// </exception>
    public double Predict(double[] x)
    {
      double[]? w = this.coeffs;
      if (w == null)
        throw new InvalidOperationException(
          "Model has not been trained; call Train before Predict");

      int n = x.Length;  // number predictors
      if (n != w.Length - 1)
        throw new ArgumentException(
          "Expected " + (w.Length - 1) +
          " predictor values but got " + n, nameof(x));

      // accumulate the weighted predictors first and add the
      // constant last, preserving the original summation order
      double sum = 0.0;
      for (int i = 0; i < n; ++i)
        sum += x[i] * w[i + 1];

      sum += w[0];  // add the constant
      return sum;
    }

  } // class LinearRidgeRegress

  /// <summary>
  /// Matrix, vector, and text-file helpers. Matrices are jagged
  /// arrays (double[][]) indexed as [row][column].
  /// </summary>
  public class Utils
  {
    /// <summary>
    /// Copies the leading nRows * nCols values of vec into a new
    /// matrix, filling row by row.
    /// </summary>
    /// <param name="vec">Source values; needs at least nRows * nCols.</param>
    /// <param name="nRows">Number of rows in the result.</param>
    /// <param name="nCols">Number of columns in the result.</param>
    /// <returns>A new nRows x nCols matrix.</returns>
    public static double[][] VecToMat(double[] vec,
      int nRows, int nCols)
    {
      double[][] result = MatCreate(nRows, nCols);
      int k = 0;
      for (int i = 0; i < nRows; ++i)
        for (int j = 0; j < nCols; ++j)
          result[i][j] = vec[k++];
      return result;
    }

    /// <summary>Allocates a rows x cols matrix of zeros.</summary>
    public static double[][] MatCreate(int rows, int cols)
    {
      double[][] result = new double[rows][];
      for (int i = 0; i < rows; ++i)
        result[i] = new double[cols];
      return result;
    }

    /// <summary>
    /// Flattens a matrix into a vector, row by row. The column
    /// count is taken from the first row.
    /// </summary>
    /// <param name="m">Matrix to flatten.</param>
    /// <returns>
    /// A new vector of length rows * cols; empty when m has no rows.
    /// </returns>
    public static double[] MatToVec(double[][] m)
    {
      int rows = m.Length;
      if (rows == 0)
        return new double[0];  // no m[0] to read a width from

      int cols = m[0].Length;
      double[] result = new double[rows * cols];
      int k = 0;
      for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
          result[k++] = m[i][j];
      return result;
    }

    /// <summary>Computes the matrix product matA * matB.</summary>
    /// <param name="matA">Left operand.</param>
    /// <param name="matB">Right operand.</param>
    /// <returns>A new matrix with matA's rows and matB's columns.</returns>
    /// <exception cref="ArgumentException">
    /// matA's column count differs from matB's row count.
    /// </exception>
    public static double[][] MatProduct(double[][] matA,
      double[][] matB)
    {
      int aRows = matA.Length;
      int aCols = matA[0].Length;
      int bRows = matB.Length;
      int bCols = matB[0].Length;
      if (aCols != bRows)
        throw new ArgumentException("Non-conformable matrices");

      double[][] result = MatCreate(aRows, bCols);

      for (int i = 0; i < aRows; ++i) // each row of A
        for (int j = 0; j < bCols; ++j) // each col of B
          for (int k = 0; k < aCols; ++k)
            result[i][j] += matA[i][k] * matB[k][j];

      return result;
    }

    /// <summary>Returns a new matrix that is the transpose of m.</summary>
    public static double[][] MatTranspose(double[][] m)
    {
      int nr = m.Length;
      int nc = m[0].Length;
      double[][] result = MatCreate(nc, nr);  // dimensions swapped
      for (int i = 0; i < nr; ++i)
        for (int j = 0; j < nc; ++j)
          result[j][i] = m[i][j];
      return result;
    }

    // -------

    /// <summary>
    /// Inverts a square matrix by LU decomposition, solving for one
    /// column of the inverse at a time. The input is not modified.
    /// </summary>
    /// <param name="m">Square matrix to invert.</param>
    /// <returns>A new matrix holding the inverse.</returns>
    /// <remarks>
    /// Assumes the matrix has an inverse. A zero pivot is not
    /// reported; it leads to division by zero and non-finite values
    /// in the result.
    /// </remarks>
    public static double[][] MatInverse(double[][] m)
    {
      int n = m.Length;
      double[][] result = MatCreate(n, n);

      // lum: combined lower & upper; perm: row permutation.
      // The returned toggle is only needed for determinants.
      MatDecompose(m, out double[][] lum, out int[] perm);

      double[] b = new double[n];
      for (int i = 0; i < n; ++i)
      {
        // column i of the identity, rows reordered by perm
        for (int j = 0; j < n; ++j)
          b[j] = (i == perm[j]) ? 1.0 : 0.0;

        double[] x = Reduce(lum, b);
        for (int j = 0; j < n; ++j)
          result[j][i] = x[j];
      }
      return result;
    }

    /// <summary>
    /// LU decomposition with partial (row) pivoting, used for the
    /// matrix inverse (the original author labels it Crout's method).
    /// </summary>
    /// <param name="m">Square matrix; it is copied, not modified.</param>
    /// <param name="lum">
    /// Combined factors: the multipliers of the lower factor below
    /// the diagonal (its diagonal of 1.0s is implied, not stored)
    /// and the upper factor on and above the diagonal.
    /// </param>
    /// <param name="perm">Row permutation applied while pivoting.</param>
    /// <returns>
    /// +1 for an even number of row swaps, -1 for an odd number.
    /// </returns>
    private static int MatDecompose(double[][] m,
      out double[][] lum, out int[] perm)
    {
      // even (+1) or odd (-1) row permutations
      int toggle = +1;
      int n = m.Length;

      // work on a copy so the caller's matrix is untouched
      lum = MatCreate(n, n);
      for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
          lum[i][j] = m[i][j];

      // start from the identity permutation
      perm = new int[n];
      for (int i = 0; i < n; ++i)
        perm[i] = i;

      for (int j = 0; j < n - 1; ++j) // last column needs no elimination
      {
        // pick the row at or below j with the largest |value| in column j
        double max = Math.Abs(lum[j][j]);
        int piv = j;

        for (int i = j + 1; i < n; ++i)
        {
          double xij = Math.Abs(lum[i][j]);
          if (xij > max)
          {
            max = xij;
            piv = i;
          }
        }

        if (piv != j)
        {
          double[] tmp = lum[piv]; // swap rows j, piv
          lum[piv] = lum[j];
          lum[j] = tmp;

          int t = perm[piv]; // swap perm elements
          perm[piv] = perm[j];
          perm[j] = t;

          toggle = -toggle;
        }

        // a zero pivot leaves the column as is; no error is raised
        double xjj = lum[j][j];
        if (xjj != 0.0)
        {
          for (int i = j + 1; i < n; ++i)
          {
            double xij = lum[i][j] / xjj;
            lum[i][j] = xij;  // store the multiplier in the lower part
            for (int k = j + 1; k < n; ++k)
              lum[i][k] -= xij * lum[j][k];
          }
        }
      } // j

      return toggle;  // for determinant
    } // MatDecompose

    /// <summary>
    /// Solves the system for one right-hand side using the combined
    /// LU matrix: forward substitution through the lower part
    /// (implied diagonal of 1.0s), then back substitution through
    /// the upper part.
    /// </summary>
    /// <param name="luMatrix">Combined matrix from MatDecompose.</param>
    /// <param name="b">Right-hand side; not modified.</param>
    /// <returns>A new vector holding the solution.</returns>
    private static double[] Reduce(double[][] luMatrix,
      double[] b) // helper
    {
      int n = luMatrix.Length;
      double[] x = new double[n];
      Array.Copy(b, x, n);

      // forward substitution
      for (int i = 1; i < n; ++i)
      {
        double sum = x[i];
        for (int j = 0; j < i; ++j)
          sum -= luMatrix[i][j] * x[j];
        x[i] = sum;
      }

      // back substitution
      x[n - 1] /= luMatrix[n - 1][n - 1];
      for (int i = n - 2; i >= 0; --i)
      {
        double sum = x[i];
        for (int j = i + 1; j < n; ++j)
          sum -= luMatrix[i][j] * x[j];
        x[i] = sum / luMatrix[i][i];
      }

      return x;
    } // Reduce

    // -------

    /// <summary>
    /// True when a line holds data: it is not blank and does not
    /// start with the comment marker (ordinal comparison).
    /// </summary>
    private static bool IsDataLine(string line, string comment)
    {
      if (line.Trim().Length == 0)
        return false;
      return !line.StartsWith(comment, StringComparison.Ordinal);
    }

    /// <summary>
    /// Counts the data lines in a file, i.e. the lines MatLoad
    /// will turn into rows.
    /// </summary>
    private static int NumNonCommentLines(string fn,
      string comment)
    {
      int ct = 0;
      // StreamReader(path) opens the file read-only; the using
      // block closes it even if reading throws
      using (StreamReader sr = new StreamReader(fn))
      {
        string? line;
        while ((line = sr.ReadLine()) != null)
          if (IsDataLine(line, comment))
            ++ct;
      }
      return ct;
    }

    /// <summary>
    /// Loads selected columns of a delimited text file into a matrix.
    /// Blank lines and lines starting with the comment marker are
    /// skipped.
    /// </summary>
    /// <param name="fn">Path of the file to read.</param>
    /// <param name="usecols">Zero-based indices of the columns to keep.</param>
    /// <param name="sep">Field separator character.</param>
    /// <param name="comment">Prefix that marks a comment line.</param>
    /// <returns>One row per data line, one column per usecols entry.</returns>
    /// <remarks>
    /// Numbers are parsed with the invariant culture ('.' as the
    /// decimal point) so the result does not depend on the machine's
    /// regional settings.
    /// </remarks>
    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      // first pass: count rows so the result can be allocated
      int nRows = NumNonCommentLines(fn, comment);
      int nCols = usecols.Length;
      double[][] result = MatCreate(nRows, nCols);

      using (StreamReader sr = new StreamReader(fn))
      {
        int i = 0;
        string? line;
        while ((line = sr.ReadLine()) != null)
        {
          if (!IsDataLine(line, comment))
            continue;
          string[] tokens = line.Split(sep);
          for (int j = 0; j < nCols; ++j)
          {
            int k = usecols[j];  // into tokens
            result[i][j] = double.Parse(tokens[k],
              System.Globalization.CultureInfo.InvariantCulture);
          }
          ++i;
        }
      }
      return result;
    }

    /// <summary>
    /// Returns a copy of m with a leading column of 1.0s added
    /// (a design matrix). The column count is taken from the
    /// first row.
    /// </summary>
    public static double[][] MatMakeDesign(double[][] m)
    {
      int nRows = m.Length; int nCols = m[0].Length;
      double[][] result = MatCreate(nRows, nCols + 1);
      for (int i = 0; i < nRows; ++i)
      {
        result[i][0] = 1.0;
        for (int j = 0; j < nCols; ++j)
          result[i][j + 1] = m[i][j];
      }
      return result;
    }

    /// <summary>
    /// Writes a matrix to the console, one row per line.
    /// </summary>
    /// <param name="m">Matrix to print.</param>
    /// <param name="dec">Number of decimals per value.</param>
    /// <param name="wid">Width each value is left-padded to.</param>
    public static void MatShow(double[][] m,
      int dec, int wid)
    {
      for (int i = 0; i < m.Length; ++i)
      {
        for (int j = 0; j < m[0].Length; ++j)
        {
          double v = m[i][j];
          // snap near-zero values to 0.0 before formatting
          if (Math.Abs(v) < 1.0e-8) v = 0.0;
          Console.Write(v.ToString("F" +
            dec).PadLeft(wid));
        }
        Console.WriteLine("");
      }
    }

    /// <summary>
    /// Writes a vector to the console on one line.
    /// </summary>
    /// <param name="vec">Vector to print.</param>
    /// <param name="dec">Number of decimals per value.</param>
    /// <param name="wid">Width each value is left-padded to.</param>
    /// <param name="newLine">Whether to end the line afterwards.</param>
    public static void VecShow(double[] vec,
      int dec, int wid, bool newLine)
    {
      for (int i = 0; i < vec.Length; ++i)
      {
        double x = vec[i];
        // snap near-zero values to 0.0 before formatting
        if (Math.Abs(x) < 1.0e-8) x = 0.0;
        Console.Write(x.ToString("F" +
          dec).PadLeft(wid));
      }
      if (newLine == true)
        Console.WriteLine("");
    }

  } // class Utils

} // ns