﻿using System;
using System.IO;

namespace GaussianProcessRegression
{
  internal class GPRProgram
  {
    // Demo driver. Loads the train/test files, prints a small
    // hyperparameter grid (RMSE and accuracy for every
    // combination), fits one model with fixed settings, reports
    // its accuracy and shows one prediction with its std.
    static void Main(string[] args)
    {
      Console.WriteLine("\nGaussian Process Regression ");

      // 1. Load data: columns 0-5 are read as predictors,
      //    column 6 as the target
      Console.WriteLine("\nLoading data from file ");
      int[] xCols = new int[] { 0, 1, 2, 3, 4, 5 };
      int[] yCol = new int[] { 6 };

      string trainFile =
        "..\\..\\..\\Data\\synthetic_train_200.txt";
      double[][] trainX =
        Utils.MatLoad(trainFile, xCols, ',', "#");
      double[][] trainYMat =
        Utils.MatLoad(trainFile, yCol, ',', "#");
      double[] trainY = Utils.MatToVec(trainYMat);

      string testFile =
        "..\\..\\..\\Data\\synthetic_test_40.txt";
      double[][] testX =
        Utils.MatLoad(testFile, xCols, ',', "#");
      double[][] testYMat =
        Utils.MatLoad(testFile, yCol, ',', "#");
      double[] testY = Utils.MatToVec(testYMat);
      Console.WriteLine("Done ");

      Console.WriteLine("\nFirst four X predictors: ");
      for (int row = 0; row < 4; ++row)
        Utils.VecShow(trainX[row], 4, 8);
      Console.WriteLine("\nFirst four y targets: ");
      for (int row = 0; row < 4; ++row)
      {
        string shown = trainY[row].ToString("F4");
        Console.WriteLine(shown.PadLeft(8));
      }

      // 2. Grid over theta, lenScale and alpha; one report
      //    line per combination
      Console.WriteLine("\nExploring hyperparameters ");
      double[] alphas = { 0.01, 0.05 };
      double[] thetas = { 0.50, 1.00 };
      double[] lenScales = { 0.60, 1.00, 2.00 };

      for (int ti = 0; ti < thetas.Length; ++ti)
      {
        for (int si = 0; si < lenScales.Length; ++si)
        {
          for (int ai = 0; ai < alphas.Length; ++ai)
          {
            double t = thetas[ti];
            double s = lenScales[si];
            double a = alphas[ai];

            GPR candidate =
              new GPR(t, s, a, trainX, trainY);
            double rmseTrain =
              candidate.RootMSE(trainX, trainY);
            double rmseTest =
              candidate.RootMSE(testX, testY);
            double accTrain =
              candidate.Accuracy(trainX, trainY, 0.10);
            double accTest =
              candidate.Accuracy(testX, testY, 0.10);

            string report =
              " theta = " + t.ToString("F2") +
              "  lenScale = " + s.ToString("F2") +
              "  alpha = " + a.ToString("F2") +
              "  |  train RMSE = " + rmseTrain.ToString("F4") +
              "  test RMSE = " + rmseTest.ToString("F4") +
              "  train acc = " + accTrain.ToString("F4") +
              "  test acc = " + accTest.ToString("F4");
            Console.WriteLine(report);
          }
        }
      }

      // 3. Create and train the model with fixed settings
      Console.WriteLine("\nCreating and training GPR model ");
      double theta = 0.50;    // "constant kernel"
      double lenScale = 2.0;  // RBF parameter
      double alpha = 0.01;    // noise
      string settings =
        "Setting theta = " + theta.ToString("F2") +
        ", lenScale = " + lenScale.ToString("F2") +
        ", alpha = " + alpha.ToString("F2");
      Console.WriteLine(settings);

      // the constructor does the training
      GPR model =
        new GPR(theta, lenScale, alpha, trainX, trainY);

      // 4. Evaluate the model
      Console.WriteLine("\nEvaluate model acc" +
        " (within 0.10 of true)");
      double finalTrainAcc =
        model.Accuracy(trainX, trainY, 0.10);
      double finalTestAcc =
        model.Accuracy(testX, testY, 0.10);
      Console.WriteLine("Train acc = " +
        finalTrainAcc.ToString("F4"));
      Console.WriteLine("Test acc = " +
        finalTestAcc.ToString("F4"));

      // 5. Predict for one previously unseen item
      Console.WriteLine("\nPredicting for (0.1, 0.2, 0.3," +
        " 0.4, 0.5, 0.6) ");
      double[][] X = new double[][]
      {
        new double[] { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6 }
      };

      // results[0] holds the means, results[1] the stds
      double[][] results = model.Predict(X);

      Console.WriteLine("Predicted y value: ");
      Utils.VecShow(results[0], 4, 8);
      Console.WriteLine("Predicted std: ");
      Utils.VecShow(results[1], 4, 8);

      Console.WriteLine("\nEnd GPR demo ");
      Console.ReadLine();
    } // Main()
  } // Program

  public class GPR
  {
    public double theta = 0.0;
    public double lenScale = 0.0;
    public double noise = 0.0;
    public double[][] trainX;
    public double[] trainY;
    public double[][] invCovarMat;

    // Stores the hyperparameters and the training data, then
    // precomputes the inverse of the training covariance
    // matrix with `noise` added to every diagonal element.
    // Construction is therefore the "training" step.
    public GPR(double theta, double lenScale,
      double noise, double[][] trainX, double[] trainY)
    {
      this.theta = theta;
      this.lenScale = lenScale;
      this.noise = noise;
      this.trainX = trainX;
      this.trainY = trainY;

      // kernel matrix of the training items against themselves
      double[][] K = this.ComputeCovarMat(trainX, trainX);

      // add the noise term (aka alpha, lambda) on the diagonal
      for (int d = 0; d < K.Length; ++d)
        K[d][d] = K[d][d] + noise;

      this.invCovarMat = Utils.MatInverse(K);
    } // ctor()

    // Builds the kernel matrix between two sets of items:
    // entry [i][j] is KernelFunc(X1[i], X2[j]), so the result
    // has X1.Length rows and X2.Length columns.
    public double[][] ComputeCovarMat(double[][] X1,
      double[][] X2)
    {
      double[][] K = new double[X1.Length][];
      for (int r = 0; r < X1.Length; ++r)
      {
        double[] row = new double[X2.Length];
        for (int c = 0; c < row.Length; ++c)
          row[c] = this.KernelFunc(X1[r], X2[c]);
        K[r] = row;
      }
      return K;
    }

    // Kernel value for a pair of items: constant * RBF, i.e.
    //   theta * exp( -||x1 - x2||^2 / (2 * lenScale^2) )
    // The number of components compared is x1.Length.
    public double KernelFunc(double[] x1, double[] x2)
    {
      // squared Euclidean distance between the two items
      double sqDist = 0.0;
      for (int d = 0; d < x1.Length; ++d)
      {
        double gap = x1[d] - x2[d];
        sqDist += gap * gap;
      }

      double scale =
        -1 / (2 * (this.lenScale * this.lenScale));
      double rbf = Math.Exp(sqDist * scale);
      return this.theta * rbf;
    }

    // Predicts a mean and a standard deviation for every row
    // of X.
    //   X       - items to predict, one row per item
    //   returns - two arrays of length X.Length:
    //             result[0] = predicted means,
    //             result[1] = predicted standard deviations
    // With X* = training items and Y* = training targets:
    //   means = K(X,X*) * inv(K(X*,X*)) * Y*
    //   covar = K(X,X) - K(X,X*) * inv(K(X*,X*)) * K(X,X*)^T
    // Only the diagonal of covar is needed for the stds, so
    // only the diagonal is computed; the full (n,n) matrices
    // are never built. An empty X gives two empty arrays.
    public double[][] Predict(double[][] X)
    {
      int n = X.Length;
      double[] means = new double[n];
      double[] stds = new double[n];

      if (n > 0)
      {
        // K(X, X*): (n, nTrain)
        double[][] kStar =
          this.ComputeCovarMat(X, this.trainX);

        // K(X, X*) * inv(K(X*, X*)): (n, nTrain); computed
        // once and shared by the mean and variance formulas
        double[][] kStarInv =
          Utils.MatProduct(kStar, this.invCovarMat);

        int nTrain = this.trainY.Length;
        for (int i = 0; i < n; ++i)
        {
          double mean = 0.0;       // row i of kStarInv * Y*
          double reduction = 0.0;  // (kStarInv * kStar^T)[i][i]
          for (int k = 0; k < nTrain; ++k)
          {
            mean += kStarInv[i][k] * this.trainY[k];
            reduction += kStarInv[i][k] * kStar[i][k];
          }
          means[i] = mean;

          // Round-off can push a near-zero variance slightly
          // below zero; clamp so the sqrt cannot give NaN.
          double variance =
            this.KernelFunc(X[i], X[i]) - reduction;
          stds[i] = Math.Sqrt(Math.Max(variance, 0.0));
        }
      }

      double[][] result = new double[2][];  // means, stds
      result[0] = means;
      result[1] = stds;
      return result;
    } // Predict()

    // Fraction of items whose predicted mean is close to the
    // target. A prediction counts as correct when
    //   |predicted - target| < |pctClose * target|
    // (strictly less than). The result is the number of
    // correct predictions divided by the number of predictions.
    public double Accuracy(double[][] dataX, double[] dataY,
      double pctClose)
    {
      // only the means (index 0) are needed here
      double[] predicted = this.Predict(dataX)[0];

      int hits = 0;
      for (int idx = 0; idx < predicted.Length; ++idx)
      {
        double err = Math.Abs(predicted[idx] - dataY[idx]);
        double tol = Math.Abs(pctClose * dataY[idx]);
        if (err < tol)
          ++hits;
      }
      return (hits * 1.0) / predicted.Length;
    }

    // Root mean squared error between the predicted means for
    // dataX and the targets in dataY. The squared errors are
    // summed over all predictions and divided by dataY.Length
    // before the square root is taken.
    public double RootMSE(double[][] dataX, double[] dataY)
    {
      // index 0 = means; the stds are not needed
      double[] predicted = this.Predict(dataX)[0];

      double total = 0.0;
      int idx = 0;
      foreach (double p in predicted)
      {
        double resid = p - dataY[idx];
        total += resid * resid;
        ++idx;
      }

      return Math.Sqrt(total / dataY.Length);
    }

  } // class GPR

  public class Utils
  {
    //public static double[][] VecToMat(double[] vec)
    //{
    //  // vector to row vec/matrix
    //  double[][] result = MatCreate(vec.Length, 1);
    //  for (int i = 0; i < vec.Length; ++i)
    //    result[i][0] = vec[i];
    //  return result;
    //}

    // Reshapes a vector into an nRows x nCols matrix, filling
    // row by row: element [r][c] is vec[r * nCols + c].
    // Only the first nRows * nCols elements of vec are read.
    public static double[][] VecToMat(double[] vec,
      int nRows, int nCols)
    {
      double[][] mat = new double[nRows][];
      for (int r = 0; r < nRows; ++r)
        mat[r] = new double[nCols];

      for (int r = 0; r < nRows; ++r)
      {
        int offset = r * nCols;  // start of row r inside vec
        for (int c = 0; c < nCols; ++c)
          mat[r][c] = vec[offset + c];
      }
      return mat;
    }

    // Allocates a rows x cols jagged matrix. Every row is its
    // own array and all elements start at 0.0.
    public static double[][] MatCreate(int rows,
      int cols)
    {
      double[][] grid = new double[rows][];
      int r = 0;
      while (r < rows)
      {
        grid[r] = new double[cols];
        ++r;
      }
      return grid;
    }

    // Flattens a matrix into a vector, row by row. The column
    // count is taken from the first row, so the result has
    // m.Length * m[0].Length elements.
    public static double[] MatToVec(double[][] m)
    {
      int nRows = m.Length;
      int nCols = m[0].Length;
      double[] flat = new double[nRows * nCols];

      // position p in the vector maps back to
      // row p / nCols, column p % nCols
      for (int p = 0; p < flat.Length; ++p)
        flat[p] = m[p / nCols][p % nCols];

      return flat;
    }

    // Matrix product matA * matB. Dimensions are taken from
    // the first row of each matrix; throws when the column
    // count of matA differs from the row count of matB.
    public static double[][] MatProduct(double[][] matA,
      double[][] matB)
    {
      int outRows = matA.Length;
      int inner = matA[0].Length;
      int bRowCount = matB.Length;
      int outCols = matB[0].Length;
      if (inner != bRowCount)
        throw new Exception("Non-conformable matrices");

      double[][] product = new double[outRows][];
      for (int r = 0; r < outRows; ++r)
      {
        double[] aRow = matA[r];
        double[] outRow = new double[outCols];
        for (int c = 0; c < outCols; ++c)
        {
          // dot product of row r of A with column c of B
          double dot = 0.0;
          for (int t = 0; t < inner; ++t)
            dot += aRow[t] * matB[t][c];
          outRow[c] = dot;
        }
        product[r] = outRow;
      }

      return product;
    }

    //public static double[] VecMatProd(double[] v,
    //  double[][] m)
    //{
    //  // ex: 
    //  int nRows = m.Length;
    //  int nCols = m[0].Length;
    //  int n = v.Length;
    //  if (n != nCols)
    //    throw new Exception("non-comform in VecMatProd");

    //  double[] result = new double[n];
    //  for (int i = 0; i < n; ++i)
    //  {
    //    for (int j = 0; j < nCols; ++j)
    //    {
    //      result[i] += v[j] * m[i][j];
    //    }
    //  }
    //  return result;
    //}

    //public static double[] MatVecProd(double[][] mat,
    //  double[] vec)
    //{
    //  int nRows = mat.Length; int nCols = mat[0].Length;
    //  if (vec.Length != nCols)
    //    throw new Exception("Non-conforme MatVecProd() ");
    //  double[] result = new double[nRows];
    //  for (int i = 0; i < nRows; ++i)
    //  {
    //    for (int j = 0; j < nCols; ++j)
    //    {
    //      result[i] += mat[i][j] * vec[j];
    //    }
    //  }
    //  return result;
    //}

    // Returns the transpose of m as a new matrix: an input of
    // r rows by c columns (c taken from the first row) gives
    // c rows by r columns.
    public static double[][] MatTranspose(double[][] m)
    {
      int srcRows = m.Length;
      int srcCols = m[0].Length;

      double[][] flipped = new double[srcCols][];
      for (int c = 0; c < srcCols; ++c)
        flipped[c] = new double[srcRows];

      for (int r = 0; r < srcRows; ++r)
      {
        double[] srcRow = m[r];
        for (int c = 0; c < srcCols; ++c)
          flipped[c][r] = srcRow[c];
      }
      return flipped;
    }

    // Element-wise difference matA - matB.
    //   matA, matB - matrices of the same shape; the shape is
    //                taken from matA (column count from its
    //                first row)
    //   returns    - new matrix with the same shape as matA,
    //                where [i][j] = matA[i][j] - matB[i][j]
    // The result keeps the row/column orientation of the
    // inputs (it is not transposed).
    public static double[][] MatDifference(double[][] matA,
      double[][] matB)
    {
      int nr = matA.Length;
      int nc = matA[0].Length;
      double[][] result = new double[nr][];
      for (int i = 0; i < nr; ++i)
      {
        result[i] = new double[nc];
        for (int j = 0; j < nc; ++j)
          result[i][j] = matA[i][j] - matB[i][j];
      }
      return result;
    }

    // -------

    // Inverts a square matrix via its LU decomposition.
    // Assumes the determinant is not 0, that is, the matrix
    // does have an inverse. The inverse is built one column
    // at a time: column c is the solution of (LU) x = rhs,
    // where rhs is column c of the identity with its rows
    // reordered by the permutation from the decomposition.
    public static double[][] MatInverse(double[][] m)
    {
      int n = m.Length;

      double[][] lum;  // combined lower & upper
      int[] perm;      // row permutation
      MatDecompose(m, out lum, out perm);  // ignore return

      // every element is assigned in the loop below
      double[][] inverse = MatCreate(n, n);

      for (int col = 0; col < n; ++col)
      {
        double[] rhs = new double[n];
        for (int r = 0; r < n; ++r)
          rhs[r] = (perm[r] == col) ? 1.0 : 0.0;

        double[] solved = Reduce(lum, rhs);
        for (int r = 0; r < n; ++r)
          inverse[r][col] = solved[r];
      }
      return inverse;
    }

    // LU decomposition with row (partial) pivoting, used for
    // the matrix inverse and determinant.
    //   m       - square matrix to decompose; it is copied,
    //             not modified
    //   lum     - out: combined lower & upper factors in one
    //             matrix. Below the diagonal are the
    //             elimination multipliers (lower part, with
    //             an implied 1.0 on its diagonal); on and
    //             above the diagonal is the upper part.
    //   perm    - out: row permutation; perm[i] is the index
    //             of the original row now stored at row i
    //   returns - +1 or -1 according to an even or odd number
    //             of row swaps
    // NOTE(review): this was originally labelled "Crout's"
    // decomposition; the unit diagonal on the lower factor
    // looks like the Doolittle convention - confirm naming.
    static int MatDecompose(double[][] m,
      out double[][] lum, out int[] perm)
    {
      // even (+1) or odd (-1) number of row permutations
      int toggle = +1;
      int n = m.Length;

      // make a copy of m[][] into the result lum[][]
      lum = MatCreate(n, n);
      for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
          lum[i][j] = m[i][j];

      // perm[] starts as the identity permutation
      perm = new int[n];
      for (int i = 0; i < n; ++i)
        perm[i] = i;

      // process columns 0 .. n-2; the last column has no
      // rows below the diagonal to eliminate
      for (int j = 0; j < n - 1; ++j)
      {
        // pivot search: row at or below the diagonal with the
        // largest absolute value in column j
        double max = Math.Abs(lum[j][j]);
        int piv = j;

        for (int i = j + 1; i < n; ++i) // pivot index
        {
          double xij = Math.Abs(lum[i][j]);
          if (xij > max)
          {
            max = xij;
            piv = i;
          }
        } // i

        if (piv != j)
        {
          // swap rows j and piv by exchanging the row
          // references of the jagged array
          double[] tmp = lum[piv];
          lum[piv] = lum[j];
          lum[j] = tmp;

          int t = perm[piv]; // swap perm elements
          perm[piv] = perm[j];
          perm[j] = t;

          toggle = -toggle;
        }

        // eliminate column j below the diagonal; when the
        // pivot is exactly 0.0 the column is left untouched
        double xjj = lum[j][j];
        if (xjj != 0.0)
        {
          for (int i = j + 1; i < n; ++i)
          {
            // multiplier, stored in place in the lower part
            double xij = lum[i][j] / xjj;
            lum[i][j] = xij;
            for (int k = j + 1; k < n; ++k)
              lum[i][k] -= xij * lum[j][k];
          }
        }

      } // j

      return toggle;  // sign factor for the determinant
    } // MatDecompose

    // Helper: solves (L * U) x = b for x, given the combined
    // LU matrix (multipliers below the diagonal with an
    // implied unit diagonal, upper factor on and above it).
    // b is not modified; a new vector is returned.
    static double[] Reduce(double[][] luMatrix,
      double[] b) // helper
    {
      int size = luMatrix.Length;
      double[] sol = new double[size];
      for (int r = 0; r < size; ++r)
        sol[r] = b[r];

      // forward substitution with the lower part
      int row = 1;
      while (row < size)
      {
        double acc = sol[row];
        for (int c = 0; c < row; ++c)
          acc -= luMatrix[row][c] * sol[c];
        sol[row] = acc;
        ++row;
      }

      // back substitution with the upper part; the last row
      // only needs the division by its diagonal element
      int last = size - 1;
      sol[last] = sol[last] / luMatrix[last][last];
      row = size - 2;
      while (row >= 0)
      {
        double acc = sol[row];
        for (int c = row + 1; c < size; ++c)
          acc -= luMatrix[row][c] * sol[c];
        sol[row] = acc / luMatrix[row][row];
        --row;
      }

      return sol;
    } // Reduce

    //public static double MatDeterminant(double[][] m)
    //{
    //  double[][] lum;
    //  int[] perm;
    //  double result = MatDecompose(m, out lum, out perm);
    //  for (int i = 0; i < lum.Length; ++i)
    //    result *= lum[i][i];
    //  return result;
    //}

    // -------

    // Counts the lines of a text file that do not start with
    // the comment prefix (blank lines are counted).
    //   fn      - path of the file to read
    //   comment - prefix that marks a comment line
    //   returns - number of non-comment lines
    static int NumNonCommentLines(string fn,
      string comment)
    {
      int ct = 0;
      // StreamReader(path) opens the file for reading only,
      // so read-only files work; the using block closes the
      // file even when reading throws.
      using (StreamReader sr = new StreamReader(fn))
      {
        string line;
        while ((line = sr.ReadLine()) != null)
          if (line.StartsWith(comment) == false)
            ++ct;
      }
      return ct;
    }

    // Loads selected columns of a delimited text file into a
    // matrix, one row per data line.
    //   fn      - path of the file to read
    //   usecols - zero-based indices of the columns to keep,
    //             in output order
    //   sep     - column separator character
    //   comment - lines starting with this prefix are skipped
    //   returns - matrix with one row per data line and
    //             usecols.Length columns
    // Blank lines are skipped. Numbers are parsed with the
    // invariant culture, so "0.5" means one half regardless
    // of the machine's regional settings. The file is read in
    // a single pass and opened for reading only.
    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      int nCols = usecols.Length;
      System.Collections.Generic.List<double[]> rows =
        new System.Collections.Generic.List<double[]>();

      // the using block closes the file even if a line
      // fails to parse
      using (StreamReader sr = new StreamReader(fn))
      {
        string line;
        while ((line = sr.ReadLine()) != null)
        {
          if (line.StartsWith(comment,
            StringComparison.Ordinal))
            continue;
          if (line.Trim().Length == 0)
            continue;  // nothing to parse on a blank line

          string[] tokens = line.Split(sep);
          double[] row = new double[nCols];
          for (int j = 0; j < nCols; ++j)
          {
            int k = usecols[j];  // into tokens
            row[j] = double.Parse(tokens[k],
              System.Globalization.CultureInfo.InvariantCulture);
          }
          rows.Add(row);
        }
      }
      return rows.ToArray();
    }

    // Writes a matrix to the console, one line per row. Each
    // value is formatted with `dec` decimals and left-padded
    // to `wid` characters. The number of columns shown is
    // taken from the first row. Values with magnitude below
    // 1.0e-8 are displayed as zero.
    public static void MatShow(double[][] m,
      int dec, int wid)
    {
      string fmt = "F" + dec;
      foreach (double[] row in m)
      {
        int shownCols = m[0].Length;
        for (int c = 0; c < shownCols; ++c)
        {
          // hack: show tiny values as 0
          double shown =
            (Math.Abs(row[c]) < 1.0e-8) ? 0.0 : row[c];
          Console.Write(shown.ToString(fmt).PadLeft(wid));
        }
        Console.WriteLine("");
      }
    }

    // Writes an int vector to the console on one line, each
    // value left-padded to `wid` characters, then ends the
    // line.
    public static void VecShow(int[] vec, int wid)
    {
      foreach (int item in vec)
      {
        string cell = item.ToString();
        Console.Write(cell.PadLeft(wid));
      }
      Console.WriteLine("");
    }

    // Writes a double vector to the console on one line. Each
    // value is formatted with `dec` decimals and left-padded
    // to `wid` characters. Values with magnitude below 1.0e-8
    // are displayed as zero.
    public static void VecShow(double[] vec,
      int dec, int wid)
    {
      string fmt = "F" + dec;
      foreach (double item in vec)
      {
        // hack: show tiny values as 0
        double shown =
          (Math.Abs(item) < 1.0e-8) ? 0.0 : item;
        Console.Write(shown.ToString(fmt).PadLeft(wid));
      }
      Console.WriteLine("");
    }

  } // class Utils
} // ns