﻿using System;
using System.Collections.Generic;
using System.IO;

namespace DecisionTreeRegression
{
  internal class DecisionTreeRegressionProgram
  {
    // Demo: load the people data, build a small regression tree
    // that predicts income, show part of the tree, report
    // accuracy on train and test data, and trace one prediction.
    static void Main(string[] args)
    {
      Console.WriteLine("\nDecision tree regression C# ");
      Console.WriteLine("Predict income from sex," +
        " age, State, political leaning ");

      // ----------------------------------------------------

      Console.WriteLine("\nLoading data from file ");
      // Path.Combine uses the OS directory separator, so the
      // relative data path works on Windows, Linux and macOS
      string dataDir = Path.Combine("..", "..", "..", "Data");
      string trainFile =
        Path.Combine(dataDir, "people_train_tree_10.txt");
      // file columns: sex, age, State, income, politics
      // predictors are columns 0, 1, 2, 4; target is income (3)
      double[][] trainX = Utils.MatLoad(trainFile,
        new int[] { 0, 1, 2, 4 }, ',', "#");
      double[] trainY =
        Utils.MatToVec(Utils.MatLoad(trainFile,
        new int[] { 3 }, ',', "#"));

      string testFile =
        Path.Combine(dataDir, "people_test_tree_5.txt");
      double[][] testX = Utils.MatLoad(testFile,
        new int[] { 0, 1, 2, 4 }, ',', "#");
      double[] testY =
        Utils.MatToVec(Utils.MatLoad(testFile,
        new int[] { 3 }, ',', "#"));
      Console.WriteLine("Done ");

      // do not index past the end of a short data file
      int numShow = Math.Min(3, trainX.Length);

      Console.WriteLine("\nFirst three X data: ");
      for (int i = 0; i < numShow; ++i)
        Utils.VecShow(trainX[i], 4, 9, true);

      Console.WriteLine("\nFirst three target Y : ");
      for (int i = 0; i < numShow; ++i)
        Console.WriteLine(trainY[i].ToString("F4"));

      int treeSize = 7;  // power of 2 minus 1 recommended
      Console.WriteLine("\nCreating tree, size = " + treeSize);
      // kind of each predictor column: sex, age, State, politics
      string[] columnKind =
        new string[] { "C", "N", "C", "C" };

      DecisionTree dt = new DecisionTree(treeSize, columnKind);
      dt.BuildTree(trainX, trainY);
      Console.WriteLine("Done ");

      Console.WriteLine("\nTree snapshot: ");
      // dt.ShowTree() displays every node instead of these three
      dt.ShowNode(0);
      dt.ShowNode(1);
      dt.ShowNode(2);

      Console.WriteLine("\nComputing model accuracy:");
      double trainAcc = dt.Accuracy(trainX, trainY, 0.10);
      Console.WriteLine("Train data accuracy = " +
        trainAcc.ToString("F4"));

      double testAcc = dt.Accuracy(testX, testY, 0.10);
      Console.WriteLine("Test data accuracy = " +
        testAcc.ToString("F4"));

      Console.WriteLine("\nPredicting for male, 34," +
        " Oklahoma, moderate ");
      // male = 0, age 34 -> 0.34, Oklahoma = 2, moderate = 1
      double[] x = new double[] { 0, 0.34, 2, 1 };
      // verbose mode prints the path, the rule and the result
      dt.Predict(x, verbose: true);

      Console.WriteLine("\nEnd demo ");
      Console.ReadLine();  // wait for Enter before exiting
    } // Main

  } // Program

  // --------------------------------------------------------

  class DecisionTree
  {
    public int numNodes;
    public int numClasses;  // not used by this regression tree
    public List<Node> tree;
    public string[] columnKind;  // per column: "N"um or "C"at

    // ---------- nested classes

    // One slot of the array-backed binary tree: node i has its
    // left ("less") child at 2i+1 and its right ("greater")
    // child at 2i+2.
    public class Node
    {
      public int nodeID;
      public List<int> rows;  // source data rows
      public int splitCol;
      public double splitVal;
      public double[] targetValues;  // trainY values of rows
      public double predictedY;  // avg of targets, NaN if no rows
    }

    // Helper struct: the best split found for a set of rows.
    public class SplitInfo
    {
      public int splitCol;
      public double splitVal;
      public List<int> lessRows;
      public List<int> greaterRows;
    }

    // ----------

    // Create a tree with numNodes node slots. colKind holds one
    // entry per predictor column, "N" (numeric) or "C"
    // (categorical); the array is stored by reference.
    // Throws ArgumentOutOfRangeException if numNodes < 1 and
    // ArgumentNullException if colKind is null.
    public DecisionTree(int numNodes, string[] colKind)
    {
      if (numNodes < 1)
        throw new ArgumentOutOfRangeException(nameof(numNodes),
          "A tree needs at least one node");
      if (colKind == null)
        throw new ArgumentNullException(nameof(colKind));

      this.numNodes = numNodes;
      this.tree = new List<Node>(numNodes);
      for (int i = 0; i < numNodes; ++i)
        this.tree.Add(new Node { nodeID = i });
      this.columnKind = colKind;  // by ref
    } // ctor

    // ------------------------------------------------------

    // Fill every node in index order. Node i receives its rows
    // from its parent, records the split with the lowest
    // weighted target variance, stores the mean target as its
    // prediction, and hands the "less" rows to node 2i+1 and
    // the "greater" rows to node 2i+2.
    // Throws ArgumentException / ArgumentNullException for
    // unusable training data (see ValidateTrainingData).
    public void BuildTree(double[][] trainX,
      double[] trainY)
    {
      ValidateTrainingData(trainX, trainY);

      int n = trainX.Length;
      List<int> allRows = new List<int>(n);
      for (int i = 0; i < n; ++i)
        allRows.Add(i);
      this.tree[0].rows = allRows;

      for (int i = 0; i < this.numNodes; ++i)
      {
        Node node = this.tree[i];
        node.nodeID = i;

        SplitInfo si = GetSplitInfo(trainX, trainY,
          node.rows, this.columnKind);
        node.splitCol = si.splitCol;
        node.splitVal = si.splitVal;

        node.targetValues = GetTargets(trainY, node.rows);
        node.predictedY = Average(node.targetValues);

        // si is freshly built per node, so its lists can be
        // handed to the children without copying
        int leftChild = (2 * i) + 1;
        int rightChild = (2 * i) + 2;
        if (leftChild < this.numNodes)
          this.tree[leftChild].rows = si.lessRows;
        if (rightChild < this.numNodes)
          this.tree[rightChild].rows = si.greaterRows;
      } // i
    } // BuildTree()

    // ------------------------------------------------------

    // Reject training data BuildTree cannot process: null or
    // empty inputs, X/Y length mismatch, fewer column kinds
    // than columns, or a column kind other than "N" / "C"
    // (which previously made Predict loop forever).
    private void ValidateTrainingData(double[][] trainX,
      double[] trainY)
    {
      if (trainX == null)
        throw new ArgumentNullException(nameof(trainX));
      if (trainY == null)
        throw new ArgumentNullException(nameof(trainY));
      if (trainX.Length == 0)
        throw new ArgumentException("Training data is empty",
          nameof(trainX));
      if (trainX.Length != trainY.Length)
        throw new ArgumentException(
          "trainX and trainY have different lengths",
          nameof(trainY));

      int nCols = trainX[0].Length;
      if (this.columnKind.Length < nCols)
        throw new ArgumentException(
          "columnKind has fewer entries than trainX has columns",
          nameof(trainX));
      for (int j = 0; j < nCols; ++j)
      {
        string kind = this.columnKind[j];
        if (kind != "N" && kind != "C")
          throw new ArgumentException("Column " + j +
            " kind must be \"N\" or \"C\"", nameof(trainX));
      }
    }

    // ------------------------------------------------------

    // Throws InvalidOperationException if BuildTree has not
    // been called yet (the root has no rows assigned).
    private void EnsureBuilt()
    {
      if (this.tree[0].rows == null)
        throw new InvalidOperationException(
          "BuildTree must be called before using the tree");
    }

    // ------------------------------------------------------

    // Collect the trainY values at the given row indices.
    private static double[] GetTargets(double[] trainY,
      List<int> rows)
    {
      double[] result = new double[rows.Count];
      for (int i = 0; i < rows.Count; ++i)
        result[i] = trainY[rows[i]];
      return result;
    }

    // ------------------------------------------------------

    // Arithmetic mean; NaN for an empty array (0.0 / 0).
    private static double Average(double[] targets)
    {
      double sum = 0.0;
      foreach (double t in targets)
        sum += t;
      return sum / targets.Length;
    }

    // ------------------------------------------------------

    // Routing rule shared by training and prediction.
    // "N": a value below splitVal goes left (NaN goes right).
    // "C": a value whose integer code differs from splitVal
    // goes left; a matching code goes right.
    // Throws InvalidOperationException for any other kind.
    private static bool GoesLeft(string kind, double value,
      double splitVal)
    {
      if (kind == "N")
        return value < splitVal;
      if (kind == "C")
        return (int)value != (int)splitVal;
      throw new InvalidOperationException(
        "Unknown column kind \"" + kind + "\"; expected N or C");
    }

    // ------------------------------------------------------

    // Try every distinct value of every column (among the given
    // rows) as a split point and return the split with the
    // lowest weighted variance of targets. Ties keep the first
    // candidate found (rows outer loop, columns inner loop).
    // For empty rows the result is column 0, value 0.0 and two
    // empty row lists.
    private static SplitInfo GetSplitInfo(double[][]
      trainX, double[] trainY, List<int> rows,
      string[] colKind)
    {
      int nCols = trainX[0].Length;
      SplitInfo best = new SplitInfo
      {
        splitCol = 0,
        splitVal = 0.0,
        lessRows = new List<int>(),
        greaterRows = new List<int>()
      };
      double bestVariability = double.MaxValue;

      // a repeated value gives an identical partition, which can
      // never beat its first occurrence (strict <), so skip it
      HashSet<double>[] tried = new HashSet<double>[nCols];
      for (int j = 0; j < nCols; ++j)
        tried[j] = new HashSet<double>();

      foreach (int i in rows)
      {
        for (int j = 0; j < nCols; ++j)
        {
          double splitVal = trainX[i][j];
          if (!tried[j].Add(splitVal))
            continue;

          string kind = colKind[j];  // "N"um or "C"at
          List<int> lessRows = new List<int>();
          List<int> greaterRows = new List<int>();
          foreach (int ii in rows)  // walk column
          {
            if (GoesLeft(kind, trainX[ii][j], splitVal))
              lessRows.Add(ii);
            else
              greaterRows.Add(ii);
          }

          double meanVariability =
            MeanVariability(trainY, lessRows, greaterRows);
          if (meanVariability < bestVariability)
          {
            bestVariability = meanVariability;
            best.splitCol = j;
            best.splitVal = splitVal;
            best.lessRows = lessRows;
            best.greaterRows = greaterRows;
          }
        } // j
      } // i

      return best;
    }

    // ------------------------------------------------------

    // Population variance of the targets at the given rows
    // (lower is better); 0.0 for no rows.
    private static double Variability(double[]
      trainY, List<int> rows)
    {
      int n = rows.Count;
      if (n == 0) return 0.0;  // an empty side has weight 0

      double sum = 0.0;
      foreach (int idx in rows)
        sum += trainY[idx];
      double mean = sum / n;

      // separate accumulator: reusing 'sum' here would return
      // mean + variance instead of the variance
      double sumSq = 0.0;
      foreach (int idx in rows)
      {
        double diff = trainY[idx] - mean;
        sumSq += diff * diff;
      }
      return sumSq / n;  // variance
    }

    // ------------------------------------------------------

    // Variance of each side weighted by that side's share of
    // the rows; 0.0 when both sides are empty.
    private static double MeanVariability(double[] trainY,
      List<int> rows1, List<int> rows2)
    {
      int count1 = rows1.Count;
      int count2 = rows2.Count;
      int total = count1 + count2;
      if (total == 0)
        return 0.0;

      double wt1 = (count1 * 1.0) / total;
      double wt2 = (count2 * 1.0) / total;
      return (wt1 * Variability(trainY, rows1)) +
        (wt2 * Variability(trainY, rows2));
    }

    // ------------------------------------------------------

    public void ShowTree()  // show all nodes in tree
    {
      for (int i = 0; i < this.numNodes; ++i)
        ShowNode(i);
    }

    // ------------------------------------------------------

    // Print one node's rows, targets, prediction and split.
    // Throws InvalidOperationException before BuildTree.
    public void ShowNode(int nodeID)
    {
      EnsureBuilt();
      Node node = this.tree[nodeID];

      Console.WriteLine("\n==========");
      Console.WriteLine("Node ID: " + node.nodeID);

      Console.WriteLine("\nSource rows: ");
      for (int i = 0; i < node.rows.Count; ++i)
      {
        if (i > 0 && i % 10 == 0) Console.WriteLine("");
        Console.Write(node.rows[i].ToString().PadLeft(4) + " ");
      }
      Console.WriteLine("");

      Console.WriteLine("Node target values: ");
      for (int i = 0; i < node.targetValues.Length; ++i)
      {
        if (i > 0 && i % 10 == 0) Console.WriteLine("");
        Console.Write(
          node.targetValues[i].ToString("F4").PadLeft(8));
      }
      Console.WriteLine("");

      Console.WriteLine("Node predicted y: " +
        node.predictedY.ToString("F4"));

      Console.WriteLine("\nNode split column: " +
        node.splitCol);
      Console.WriteLine("Node split value: " +
        node.splitVal.ToString("F2"));
      Console.WriteLine("==========");
    }

    // ------------------------------------------------------

    // Walk from the root, routing x with the same rule used in
    // training (GoesLeft), until the next child would fall
    // outside the tree or has no rows. Returns predictedY of
    // the last node reached -- the root's mean when no move is
    // possible. When verbose, prints each step and the IF-rule
    // that was followed.
    // Throws InvalidOperationException before BuildTree.
    public double Predict(double[] x, bool verbose)
    {
      EnsureBuilt();
      bool vb = verbose;
      int currNodeID = 0;
      // start from the root's mean (not a -1 sentinel) so a
      // tree that cannot descend still predicts a real value
      double result = this.tree[0].predictedY;
      string rule = "IF (*)";  // if any  . . 
      while (true)
      {
        Node curr = this.tree[currNodeID];
        if (curr.rows.Count == 0)
          break; // at an empty node

        if (vb) Console.WriteLine("\ncurr node id = " +
          currNodeID);

        int sc = curr.splitCol;
        string scKind = this.columnKind[sc];  // "N" or "C"
        if (vb) Console.WriteLine("Column kind = " +
          scKind);

        // throws for an unknown kind instead of looping forever
        bool goLeft = GoesLeft(scKind, x[sc], curr.splitVal);

        // numeric splits print as doubles, categorical as ints
        bool numeric = scKind == "N";
        string svText = numeric ?
          curr.splitVal.ToString() :
          ((int)curr.splitVal).ToString();
        string vText = numeric ?
          x[sc].ToString() : ((int)x[sc]).ToString();
        if (vb) Console.WriteLine("Comparing " + svText +
          " in column " + sc + " with " + vText);

        int newNodeID;
        if (goLeft)
        {
          newNodeID = (2 * currNodeID) + 1;
          if (vb) Console.WriteLine("attempting move" +
            " left to Node " + newNodeID);
        }
        else
        {
          newNodeID = (2 * currNodeID) + 2;
          if (vb) Console.WriteLine("attempting move" +
            " right to Node = " + newNodeID);
        }

        if (newNodeID >= this.tree.Count)
          break;  // attempt to fall out of tree
        if (this.tree[newNodeID].rows.Count == 0)
          break;  // move to invalid Node

        currNodeID = newNodeID;
        result = this.tree[currNodeID].predictedY;

        string op;
        if (numeric)
          op = goLeft ? " < " : " >= ";
        else
          op = goLeft ? " != " : " == ";
        rule += " AND (column " + sc + op + svText + ")";

        if (vb) Console.WriteLine("new node id = " +
          currNodeID);
      } // while

      if (vb) Console.WriteLine("\n" + rule);
      if (vb) Console.WriteLine("Predicted Y = " +
        result.ToString("F5"));

      return result;
    } // Predict

    // ------------------------------------------------------

    // Fraction of items whose prediction lies strictly within
    // pctClose * |actual| of the actual value (so an actual of
    // 0 is never counted correct). NaN for empty data.
    public double Accuracy(double[][] dataX,
      double[] dataY, double pctClose)
    {
      int numCorrect = 0;
      for (int i = 0; i < dataX.Length; ++i)
      {
        double predY = Predict(dataX[i], verbose: false);
        double actualY = dataY[i];
        if (Math.Abs(predY - actualY) <
          Math.Abs(pctClose * actualY))
          ++numCorrect;
      }
      return (numCorrect * 1.0) / dataX.Length;
    }

  } // DecisionTree class

  // ---------------------------------------------------------

  public class Utils
  {
    // Reshape vec into a rows x cols matrix, filled row by row.
    // Throws ArgumentException if vec is too short.
    public static double[][] VecToMat(double[] vec,
      int rows, int cols)
    {
      if (vec.Length < rows * cols)
        throw new ArgumentException(
          "vec has fewer than rows * cols elements", nameof(vec));
      double[][] result = MatCreate(rows, cols);
      int k = 0;
      for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j)
          result[i][j] = vec[k++];
      return result;
    }

    // Allocate a zero-filled rows x cols jagged matrix.
    public static double[][] MatCreate(int rows,
      int cols)
    {
      double[][] result = new double[rows][];
      for (int i = 0; i < rows; ++i)
        result[i] = new double[cols];
      return result;
    }

    // Load the columns listed in usecols from a sep-delimited
    // text file, one matrix row per data line. Blank lines and
    // lines starting with comment are skipped; a null or empty
    // comment disables comment skipping. Numbers are parsed with
    // the invariant culture, so '.' is the decimal point on
    // every machine regardless of its locale.
    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      List<double[]> rows = new List<double[]>();
      // File.ReadLines opens the file read-only and closes it
      // when the loop ends, even if parsing throws
      foreach (string line in File.ReadLines(fn))
      {
        if (IsSkippedLine(line, comment))
          continue;
        string[] tokens = line.Split(sep);
        double[] row = new double[usecols.Length];
        for (int j = 0; j < usecols.Length; ++j)
        {
          int k = usecols[j];  // into tokens
          row[j] = double.Parse(tokens[k],
            System.Globalization.CultureInfo.InvariantCulture);
        }
        rows.Add(row);
      }
      return rows.ToArray();
    }

    // True for whitespace-only lines and for comment lines
    // (ordinal prefix match against a non-empty comment).
    private static bool IsSkippedLine(string line,
      string comment)
    {
      if (string.IsNullOrWhiteSpace(line))
        return true;
      return !string.IsNullOrEmpty(comment) &&
        line.StartsWith(comment, StringComparison.Ordinal);
    }

    // ------------------------------------------------------

    // Flatten a (possibly jagged) matrix row by row; an empty
    // matrix gives an empty vector.
    public static double[] MatToVec(double[][] m)
    {
      int total = 0;
      foreach (double[] row in m)
        total += row.Length;

      double[] result = new double[total];
      int k = 0;
      foreach (double[] row in m)
        foreach (double v in row)
          result[k++] = v;

      return result;
    }

    // Print each row on its own line; values under 1e-8 in
    // magnitude are shown as 0 to hide "-0.0000" noise.
    public static void MatShow(double[][] m,
      int dec, int wid)
    {
      for (int i = 0; i < m.Length; ++i)
      {
        for (int j = 0; j < m[i].Length; ++j)
        {
          double v = m[i][j];
          if (Math.Abs(v) < 1.0e-8) v = 0.0; // hack
          Console.Write(v.ToString("F" +
            dec).PadLeft(wid));
        }
        Console.WriteLine("");
      }
    }

    public static void VecShow(int[] vec, int wid)
    {
      for (int i = 0; i < vec.Length; ++i)
        Console.Write(vec[i].ToString().PadLeft(wid));
      Console.WriteLine("");
    }

    public static void ListShow(List<int> list, int wid)
    {
      if (list.Count == 0)
        Console.WriteLine("EMPTY LIST ");
      for (int i = 0; i < list.Count; ++i)
        Console.Write(list[i].ToString().PadLeft(wid));
      Console.WriteLine("");
    }

    // Print vec with dec decimals, each value padded to wid;
    // values under 1e-8 in magnitude are shown as 0.
    public static void VecShow(double[] vec,
      int dec, int wid, bool newLine)
    {
      for (int i = 0; i < vec.Length; ++i)
      {
        double x = vec[i];
        if (Math.Abs(x) < 1.0e-8) x = 0.0;
        Console.Write(x.ToString("F" +
          dec).PadLeft(wid));
      }
      if (newLine == true)
        Console.WriteLine("");
    }

  } // Utils class

} // ns