﻿using System;
using System.IO;

namespace KernelRidgeRegression
{
  /// <summary>
  /// Console demo: loads synthetic train/test data, fits a kernel ridge
  /// regression (KRR) model, reports accuracy and RMSE on both sets, and
  /// predicts the target for one hand-picked input vector.
  /// </summary>
  internal class KernelRidgeRegressionProgram
  {
    static void Main(string[] args)
    {
      Console.WriteLine("\nBegin C# kernel ridge regression ");

      Console.WriteLine("\nLoading train and test data ");
      string trainFile =
        "..\\..\\..\\Data\\synthetic_train.txt";
      double[][] trainX =
        Utils.MatLoad(trainFile,
        new int[] { 0, 1, 2, 3, 4 }, '\t', "#");  // 40 items
      double[] trainY =
        Utils.MatToVec(Utils.MatLoad(trainFile,
        new int[] { 5 }, '\t', "#"));

      string testFile =
        "..\\..\\..\\Data\\synthetic_test.txt";
      double[][] testX =
        Utils.MatLoad(testFile,
        new int[] { 0, 1, 2, 3, 4 }, '\t', "#");  // 10 items
      double[] testY =
        Utils.MatToVec(Utils.MatLoad(testFile,
        new int[] { 5 }, '\t', "#"));
      Console.WriteLine("Done ");

      Console.WriteLine("\nFirst three X predictors: ");
      for (int i = 0; i < 3; ++i)
        Utils.VecShow(trainX[i], 4, 9, true);
      Console.WriteLine("\nFirst three target y: ");
      for (int i = 0; i < 3; ++i)
        Console.WriteLine(trainY[i].ToString("F4").PadLeft(8));

      // explore hyperparams

      //double[] gammas = new double[]
      //  { 0.1, 0.2, 0.3, 0.5, 1.0, 1.5, 2.0 };
      //double[] alphas = new double[]
      //  { 0.0, 0.0001, 0.001, 0.01, 0.05, 0.1, 0.5 };
      //foreach (double g in gammas)
      //{
      //  foreach (double a in alphas)
      //  {
      //    KRR krr = new KRR(g, a);
      //    krr.Train(trainX, trainY);

      //    double trainRMSE = RootMSE(krr, trainX, trainY);
      //    double testRMSE = RootMSE(krr, testX, testY);
      //    double trainAcc = Accuracy(krr, trainX, trainY, 0.10);
      //    double testAcc = Accuracy(krr, testX, testY, 0.10);

      //    Console.WriteLine("gamma = " + g.ToString("F1") +
      //      " alpha = " + a.ToString("F4") +
      //      " train rmse = " + trainRMSE.ToString("F4") +
      //      " test rmse = " + testRMSE.ToString("F4") +
      //      " train acc = " + trainAcc.ToString("F4") +
      //      " test acc = " + testAcc.ToString("F4"));
      //  }
      //}

      // best setting found by the search above:
      // gamma = 0.1 alpha = 0.001
      // train rmse = 0.0030 test rmse = 0.0308
      // train acc = 1.0000 test acc = 0.8000

      Console.WriteLine("\nCreating KRR object");
      double gamma = 0.1;    // RBF param
      double alpha = 0.001;  // regularization
      Console.WriteLine("Setting RBF gamma = " +
        gamma.ToString("F1"));
      Console.WriteLine("Setting alpha noise =  " +
        alpha.ToString("F3"));
      KRR krr = new KRR(gamma, alpha);
      Console.WriteLine("Done ");

      Console.WriteLine("\nTraining model ");
      krr.Train(trainX, trainY);
      Console.WriteLine("Done ");
      // Console.WriteLine("Model weights: ");  // 1 per trainX
      // Utils.VecShow(krr.wts, 4, 9, true);

      Console.WriteLine("\nComputing model accuracy" +
        " (within 0.10) ");
      double trainAcc = Accuracy(krr, trainX, trainY, 0.10);
      double testAcc = Accuracy(krr, testX, testY, 0.10);

      Console.WriteLine("\nTrain acc = " +
        trainAcc.ToString("F4"));
      Console.WriteLine("Test acc = " +
        testAcc.ToString("F4"));

      double trainRMSE = RootMSE(krr, trainX, trainY);
      double testRMSE = RootMSE(krr, testX, testY);

      Console.WriteLine("\nTrain RMSE = " +
        trainRMSE.ToString("F4"));
      Console.WriteLine("Test RMSE = " +
        testRMSE.ToString("F4"));

      Console.WriteLine("\nPredicting for x = " +
        "[0.5, -0.5, 0.5, -0.5, 0.5] ");
      double[] x =
        new double[] { 0.5, -0.5, 0.5, -0.5, 0.5 };

      double y = krr.Predict(x);
      Console.WriteLine("Predicted y = " + y.ToString("F4"));

      Console.WriteLine("\nEnd KRR demo ");
      Console.ReadLine();  // keep the console window open
    } // Main()

    /// <summary>
    /// Fraction of items whose prediction lies strictly within
    /// <paramref name="pctClose"/> * |actual y| of the actual target.
    /// Items whose actual target is exactly 0 can never count as correct,
    /// because the tolerance band is then empty.
    /// </summary>
    /// <exception cref="ArgumentNullException">A required argument is null.</exception>
    /// <exception cref="ArgumentException">The data is empty or X/y lengths differ.</exception>
    static double Accuracy(KRR model, double[][] dataX,
      double[] dataY, double pctClose)
    {
      if (model == null)
        throw new ArgumentNullException(nameof(model));
      int n = ValidateData(dataX, dataY);

      int numCorrect = 0;
      for (int i = 0; i < n; ++i)
      {
        double actualY = dataY[i];
        double predY = model.Predict(dataX[i]);
        if (Math.Abs(actualY - predY) <
          Math.Abs(actualY * pctClose))
          ++numCorrect;
      }
      return (numCorrect * 1.0) / n;
    }

    /// <summary>
    /// Root mean squared error of the model's predictions over the data.
    /// </summary>
    /// <exception cref="ArgumentNullException">A required argument is null.</exception>
    /// <exception cref="ArgumentException">The data is empty or X/y lengths differ.</exception>
    static double RootMSE(KRR model, double[][] dataX,
      double[] dataY)
    {
      if (model == null)
        throw new ArgumentNullException(nameof(model));
      int n = ValidateData(dataX, dataY);

      double sum = 0.0;
      for (int i = 0; i < n; ++i)
      {
        double diff = dataY[i] - model.Predict(dataX[i]);
        sum += diff * diff;
      }
      return Math.Sqrt(sum / n);
    }

    // Checks that dataX/dataY are non-null, non-empty and paired one
    // target per input row; returns the number of items. Without this
    // the metrics above would return NaN (0/0) or throw an opaque
    // IndexOutOfRangeException.
    static int ValidateData(double[][] dataX, double[] dataY)
    {
      if (dataX == null)
        throw new ArgumentNullException(nameof(dataX));
      if (dataY == null)
        throw new ArgumentNullException(nameof(dataY));
      if (dataX.Length == 0)
        throw new ArgumentException("Data is empty", nameof(dataX));
      if (dataY.Length != dataX.Length)
        throw new ArgumentException(
          "dataX and dataY must have the same length", nameof(dataY));
      return dataX.Length;
    }

  } // class Program

  /// <summary>
  /// Kernel ridge regression with an RBF kernel
  /// k(a, b) = exp(-gamma * ||a - b||^2).
  /// Training computes one weight per training item from
  /// (K + alpha * I)^-1 and the targets. Prediction is the weighted sum of
  /// kernel values between the query and every stored training item.
  /// </summary>
  public class KRR
  {
    public double gamma;  // RBF kernel parameter (larger = narrower kernel)
    public double alpha;  // regularization added to the K diagonal
    public double[][] trainX;  // training inputs; Predict() needs them
    public double[] wts;  // one weight per trainX item

    /// <summary>Creates an untrained model with the given hyperparameters.</summary>
    public KRR(double gamma, double alpha)
    {
      this.gamma = gamma;
      this.alpha = alpha;
    } // ctor

    /// <summary>
    /// Fits the model: builds the train-train kernel matrix K, adds
    /// alpha to its diagonal, inverts it, and derives the weights from
    /// trainY. <paramref name="trainX"/> is stored by reference, so
    /// mutating it afterwards changes later predictions.
    /// </summary>
    /// <exception cref="ArgumentNullException">trainX or trainY is null.</exception>
    /// <exception cref="ArgumentException">
    /// Data is empty, lengths differ, or (from Utils.MatInverse) the
    /// regularized kernel matrix is singular.
    /// </exception>
    public void Train(double[][] trainX, double[] trainY)
    {
      if (trainX == null)
        throw new ArgumentNullException(nameof(trainX));
      if (trainY == null)
        throw new ArgumentNullException(nameof(trainY));
      int N = trainX.Length;
      if (N == 0)
        throw new ArgumentException("Training data is empty",
          nameof(trainX));
      if (trainY.Length != N)
        throw new ArgumentException(
          "trainX and trainY must have the same length", nameof(trainY));

      // 1. train-train K matrix. The RBF kernel is symmetric, so each
      //    off-diagonal value is computed once and mirrored.
      // 2. regularization (alpha) is added on the diagonal.
      double[][] K = Utils.MatCreate(N, N);
      for (int i = 0; i < N; ++i)
      {
        K[i][i] = this.Rbf(trainX[i], trainX[i]) + this.alpha;
        for (int j = i + 1; j < N; ++j)
        {
          double kij = this.Rbf(trainX[i], trainX[j]);
          K[i][j] = kij;
          K[j][i] = kij;
        }
      }

      // 3. weights = y^T (K + alpha*I)^-1, which equals (K + alpha*I)^-1 y
      //    because the matrix is symmetric.
      double[][] Kinv = Utils.MatInverse(K);
      double[] newWts = Utils.VecMatProd(trainY, Kinv);

      // alt approach
      //double[][] tmp1 = Utils.VecToMat(trainY, 1, N);
      //double[][] tmp2 = Utils.MatProduct(tmp1, Kinv);
      //this.wts = Utils.MatToVec(tmp2);

      // Commit state only after every step succeeded, so a failed
      // Train() never leaves trainX and wts out of sync.
      this.trainX = trainX;
      this.wts = newWts;
    } // Train

    /// <summary>
    /// RBF kernel in the gamma parameterization (as opposed to a
    /// length-scale form): exp(-gamma * squared Euclidean distance).
    /// </summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public double Rbf(double[] v1, double[] v2)
    {
      if (v1.Length != v2.Length)
        throw new ArgumentException(
          "Rbf vectors must have the same length");
      int dim = v1.Length;
      double sum = 0.0;
      for (int i = 0; i < dim; ++i)
      {
        double d = v1[i] - v2[i];
        sum += d * d;
      }
      return Math.Exp(-1 * this.gamma * sum);
    }

    /// <summary>
    /// Predicts y for input <paramref name="x"/> as
    /// sum_i wts[i] * Rbf(x, trainX[i]).
    /// </summary>
    /// <exception cref="InvalidOperationException">Train() has not been called.</exception>
    public double Predict(double[] x)
    {
      if (this.trainX == null || this.wts == null)
        throw new InvalidOperationException(
          "Train() must be called before Predict()");

      int N = this.trainX.Length;
      double sum = 0.0;
      for (int i = 0; i < N; ++i)
      {
        double k = this.Rbf(x, this.trainX[i]);
        sum += this.wts[i] * k;
      }
      return sum;

      // alt:
      // compute K(x, X')
      //double[][] K = Utils.MatCreate(N, 1);
      //for (int i = 0; i < N; ++i)
      //  K[i][0] = this.Rbf(x, this.trainX[i]);
      //double[][] y = Utils.MatProduct(Utils.VecToMat(
      //  this.wts, 1, N), K);
      //return y[0][0];
    }

  } // class KRR

  /// <summary>
  /// Small dense-matrix helpers over jagged double[][] arrays, plus
  /// text-file loading and console display used by the KRR demo.
  /// </summary>
  public class Utils
  {
    /// <summary>Allocates a rows x cols matrix filled with 0.0.</summary>
    public static double[][] MatCreate(int rows, int cols)
    {
      double[][] result = new double[rows][];
      for (int i = 0; i < rows; ++i)
        result[i] = new double[cols];
      return result;
    }

    /// <summary>
    /// Flattens a matrix into a vector in row-major order. Each row's own
    /// length is used, so ragged input is neither truncated nor overrun.
    /// An empty matrix yields an empty vector.
    /// </summary>
    public static double[] MatToVec(double[][] m)
    {
      int total = 0;
      for (int i = 0; i < m.Length; ++i)
        total += m[i].Length;

      double[] result = new double[total];
      int k = 0;
      for (int i = 0; i < m.Length; ++i)
        for (int j = 0; j < m[i].Length; ++j)
          result[k++] = m[i][j];
      return result;
    }

    /// <summary>Standard matrix product matA * matB.</summary>
    /// <exception cref="ArgumentException">Inner dimensions differ.</exception>
    public static double[][] MatProduct(double[][] matA,
      double[][] matB)
    {
      int aRows = matA.Length;
      int aCols = matA[0].Length;
      int bRows = matB.Length;
      int bCols = matB[0].Length;
      if (aCols != bRows)
        throw new ArgumentException("Non-conformable matrices");

      double[][] result = MatCreate(aRows, bCols);

      for (int i = 0; i < aRows; ++i) // each row of A
        for (int j = 0; j < bCols; ++j) // each col of B
          for (int k = 0; k < aCols; ++k)
            result[i][j] += matA[i][k] * matB[k][j];

      return result;
    }

    /// <summary>
    /// Reshapes the first nRows * nCols values of a vector (row-major)
    /// into an nRows x nCols matrix.
    /// </summary>
    /// <exception cref="ArgumentException">vec holds too few values.</exception>
    public static double[][] VecToMat(double[] vec,
      int nRows, int nCols)
    {
      if (vec.Length < nRows * nCols)
        throw new ArgumentException("Vector too short in VecToMat");

      double[][] result = MatCreate(nRows, nCols);
      int k = 0;
      for (int i = 0; i < nRows; ++i)
        for (int j = 0; j < nCols; ++j)
          result[i][j] = vec[k++];
      return result;
    }

    /// <summary>
    /// Row-vector times matrix: v (length nRows) * m (nRows x nCols),
    /// giving a vector of length nCols with
    /// result[j] = sum_i v[i] * m[i][j].
    /// </summary>
    /// <exception cref="ArgumentException">v.Length != number of rows of m.</exception>
    public static double[] VecMatProd(double[] v,
      double[][] m)
    {
      int nRows = m.Length;
      if (v.Length != nRows)
        throw new ArgumentException("non-conformable in VecMatProd");
      int nCols = (nRows == 0) ? 0 : m[0].Length;

      double[] result = new double[nCols];
      for (int i = 0; i < nRows; ++i)
      {
        double vi = v[i];
        for (int j = 0; j < nCols; ++j)
          result[j] += vi * m[i][j];
      }
      return result;
    }

    //public static double[] MatVecProd(double[][] mat,
    //  double[] vec)
    //{
    //  int nRows = mat.Length; int nCols = mat[0].Length;
    //  if (vec.Length != nCols)
    //    throw new Exception("Non-conforme in MatVecProd() ");
    //  double[] result = new double[nRows];
    //  for (int i = 0; i < nRows; ++i)
    //  {
    //    for (int j = 0; j < nCols; ++j)
    //    {
    //      result[i] += mat[i][j] * vec[j];
    //    }
    //  }
    //  return result;
    //}

    /// <summary>Returns the transpose of m (nc x nr for an nr x nc input).</summary>
    public static double[][] MatTranspose(double[][] m)
    {
      int nr = m.Length;
      if (nr == 0)
        return new double[0][];
      int nc = m[0].Length;
      double[][] result = MatCreate(nc, nr);  // note: dims swapped
      for (int i = 0; i < nr; ++i)
        for (int j = 0; j < nc; ++j)
          result[j][i] = m[i][j];
      return result;
    }

    // -------

    /// <summary>
    /// Inverse of a square matrix via LU decomposition with partial
    /// pivoting. The input matrix is not modified.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// m is not square, or a zero pivot shows it is singular.
    /// </exception>
    public static double[][] MatInverse(double[][] m)
    {
      int n = m.Length;
      for (int i = 0; i < n; ++i)
        if (m[i].Length != n)
          throw new ArgumentException("Matrix must be square in MatInverse");

      double[][] lum; // combined lower & upper
      int[] perm;     // row permutation applied during pivoting
      MatDecompose(m, out lum, out perm);  // sign return not needed here

      // An exact 0.0 on U's diagonal means the matrix is singular; without
      // this check Reduce() divides by zero and returns Infinity/NaN.
      for (int i = 0; i < n; ++i)
        if (lum[i][i] == 0.0)
          throw new ArgumentException("Matrix is singular in MatInverse");

      double[][] result = MatCreate(n, n);
      double[] b = new double[n];
      for (int i = 0; i < n; ++i)
      {
        // b = P * e_i: the i-th unit vector, rows permuted to match lum
        for (int j = 0; j < n; ++j)
          b[j] = (perm[j] == i) ? 1.0 : 0.0;

        double[] x = Reduce(lum, b);  // solves L * U * x = b
        for (int j = 0; j < n; ++j)
          result[j][i] = x[j];        // x is column i of the inverse
      }
      return result;
    }

    static int MatDecompose(double[][] m,
      out double[][] lum, out int[] perm)
    {
      // Doolittle-style LU decomposition with partial pivoting, used
      // for the matrix inverse (and, via the return value, determinant).
      // Stores combined lower & upper in lum[][]:
      //   lower gets implicit 1.0s on the diagonal (0.0s above)
      //   upper gets the lum values on the diagonal (0.0s below)
      // Stores the row permutations in perm[].
      // Returns +1 or -1 according to an even or odd number of row
      // swaps. A zero pivot column is skipped, leaving 0.0 on the
      // diagonal of lum for the caller to detect.

      int toggle = +1;  // even (+1) or odd (-1) number of row swaps
      int n = m.Length;

      // make a copy of m[][] into result lum[][]
      lum = MatCreate(n, n);
      for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
          lum[i][j] = m[i][j];

      // start with the identity permutation
      perm = new int[n];
      for (int i = 0; i < n; ++i)
        perm[i] = i;

      for (int j = 0; j < n - 1; ++j) // last column needs no elimination
      {
        // find the row with the largest |value| in column j
        double max = Math.Abs(lum[j][j]);
        int piv = j;

        for (int i = j + 1; i < n; ++i)
        {
          double xij = Math.Abs(lum[i][j]);
          if (xij > max)
          {
            max = xij;
            piv = i;
          }
        } // i

        if (piv != j)
        {
          double[] tmp = lum[piv]; // swap rows j, piv
          lum[piv] = lum[j];
          lum[j] = tmp;

          int t = perm[piv]; // swap perm elements
          perm[piv] = perm[j];
          perm[j] = t;

          toggle = -toggle;
        }

        double xjj = lum[j][j];
        if (xjj != 0.0)
        {
          for (int i = j + 1; i < n; ++i)
          {
            double xij = lum[i][j] / xjj;  // L multiplier
            lum[i][j] = xij;
            for (int k = j + 1; k < n; ++k)
              lum[i][k] -= xij * lum[j][k];
          }
        }

      } // j

      return toggle;  // for determinant
    } // MatDecompose

    // Solves L * U * x = b for x, where luMatrix is the combined LU
    // produced by MatDecompose (L has an implicit unit diagonal) and b is
    // already permuted. Forward substitution, then back substitution.
    static double[] Reduce(double[][] luMatrix,
      double[] b)
    {
      int n = luMatrix.Length;
      double[] x = new double[n];
      for (int i = 0; i < n; ++i)
        x[i] = b[i];

      // forward: L * y = b (unit diagonal, so no division)
      for (int i = 1; i < n; ++i)
      {
        double sum = x[i];
        for (int j = 0; j < i; ++j)
          sum -= luMatrix[i][j] * x[j];
        x[i] = sum;
      }

      // backward: U * x = y
      x[n - 1] /= luMatrix[n - 1][n - 1];
      for (int i = n - 2; i >= 0; --i)
      {
        double sum = x[i];
        for (int j = i + 1; j < n; ++j)
          sum -= luMatrix[i][j] * x[j];
        x[i] = sum / luMatrix[i][i];
      }

      return x;
    } // Reduce

    //public static double MatDeterminant(double[][] m)
    //{
    //  double[][] lum;
    //  int[] perm;
    //  double result = MatDecompose(m, out lum, out perm);
    //  for (int i = 0; i < lum.Length; ++i)
    //    result *= lum[i][i];
    //  return result;
    //}

    // -------

    // True when a line holds data: it is not blank/whitespace-only and
    // does not start with the comment marker (ordinal, culture-independent
    // match). A null/empty marker means "no line is a comment"; without
    // that rule StartsWith("") would classify every line as a comment.
    static bool IsDataLine(string line, string comment)
    {
      if (string.IsNullOrWhiteSpace(line))
        return false;
      if (string.IsNullOrEmpty(comment))
        return true;
      return !line.StartsWith(comment, StringComparison.Ordinal);
    }

    // Counts the data lines in file fn (see IsDataLine). The file is
    // opened read-only and closed even if reading throws.
    static int NumNonCommentLines(string fn,
      string comment)
    {
      int ct = 0;
      using (StreamReader sr = new StreamReader(fn))
      {
        string line;
        while ((line = sr.ReadLine()) != null)
          if (IsDataLine(line, comment))
            ++ct;
      }
      return ct;
    }

    /// <summary>
    /// Loads selected columns of a delimited text file into a matrix.
    /// Blank lines and lines starting with <paramref name="comment"/> are
    /// skipped. Numbers are parsed with the invariant culture ('.' as the
    /// decimal separator) regardless of the machine's locale.
    /// </summary>
    /// <param name="fn">Path of the text file.</param>
    /// <param name="usecols">Zero-based field indices to keep, in output order.</param>
    /// <param name="sep">Field separator character.</param>
    /// <param name="comment">Prefix marking comment lines.</param>
    /// <exception cref="FormatException">
    /// A line has too few fields or a field is not a number.
    /// </exception>
    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      int nRows = NumNonCommentLines(fn, comment);
      int nCols = usecols.Length;
      double[][] result = MatCreate(nRows, nCols);

      using (StreamReader sr = new StreamReader(fn))
      {
        string line;
        int i = 0;
        // i < nRows guards against the file growing between the two passes
        while (i < nRows && (line = sr.ReadLine()) != null)
        {
          if (!IsDataLine(line, comment))
            continue;

          string[] tokens = line.Split(sep);
          for (int j = 0; j < nCols; ++j)
          {
            int k = usecols[j];  // index into tokens
            if (k < 0 || k >= tokens.Length)
              throw new FormatException("Data line " + i +
                " has no field at column " + k + " in " + fn);
            result[i][j] = double.Parse(tokens[k],
              System.Globalization.CultureInfo.InvariantCulture);
          }
          ++i;
        }
      }
      return result;
    }

    //public static double[][] MatMakeDesign(double[][] m)
    //{
    //  // add a leading column of 1s to create a design matrix
    //  int nRows = m.Length; int nCols = m[0].Length;
    //  double[][] result = MatCreate(nRows, nCols + 1);
    //  for (int i = 0; i < nRows; ++i)
    //    result[i][0] = 1.0;

    //  for (int i = 0; i < nRows; ++i)
    //  {
    //    for (int j = 0; j < nCols; ++j)
    //    {
    //      result[i][j + 1] = m[i][j];
    //    }
    //  }
    //  return result;
    //}

    /// <summary>
    /// Prints a matrix, one row per line, each value with dec decimals
    /// left-padded to wid characters. Values with |v| below 1e-8 print
    /// as 0 (avoids "-0.0000").
    /// </summary>
    public static void MatShow(double[][] m,
      int dec, int wid)
    {
      for (int i = 0; i < m.Length; ++i)
      {
        for (int j = 0; j < m[i].Length; ++j)
        {
          double v = m[i][j];
          if (Math.Abs(v) < 1.0e-8) v = 0.0;
          Console.Write(v.ToString("F" +
            dec).PadLeft(wid));
        }
        Console.WriteLine("");
      }
    }

    //public static void VecShow(int[] vec, int wid)
    //{
    //  for (int i = 0; i < vec.Length; ++i)
    //    Console.Write(vec[i].ToString().PadLeft(wid));
    //  Console.WriteLine("");
    //}

    /// <summary>
    /// Prints a vector on one line (each value with dec decimals,
    /// left-padded to wid), optionally followed by a newline. Values
    /// with |x| below 1e-8 print as 0.
    /// </summary>
    public static void VecShow(double[] vec,
      int dec, int wid, bool newLine)
    {
      for (int i = 0; i < vec.Length; ++i)
      {
        double x = vec[i];
        if (Math.Abs(x) < 1.0e-8) x = 0.0;
        Console.Write(x.ToString("F" +
          dec).PadLeft(wid));
      }
      if (newLine == true)
        Console.WriteLine("");
    }

  } // class Utils

} // ns