﻿using System;
using System.Collections.Generic;
using System.IO;

// code based on the Wikipedia article algorithm

namespace ClusterSOM
{
  class ClusterSOMProgram
  {
    // Demo entry point. Loads four numeric columns of the penguin
    // subset, standardizes them, trains a 3-node SOM, and prints
    // the map nodes, the data indices assigned to each node, the
    // node-to-node distances and the per-item cluster ids. Waits
    // for Enter before returning.
    static void Main(string[] args)
    {
      Console.WriteLine(
        "\nBegin self-organizing map (SOM) clustering using C#");

      // step 1: read columns 1..4 of the comma-separated file,
      // skipping lines that start with "#" (column 0 is not loaded)
      Console.WriteLine("\nLoading 12-item Penguin subset ");
      string dataFile = "..\\..\\..\\Data\\penguin_12.txt";
      int[] featureCols = new int[] { 1, 2, 3, 4 };
      double[][] rawX =
        ClusterSOM.MatLoad(dataFile, featureCols, ',', "#");
      Console.WriteLine("\nX: ");
      ClusterSOM.MatShow(rawX, 1, 8, true);

      // step 2: z-score every column
      double[][] zX = ClusterSOM.MatStandard(rawX);
      Console.WriteLine("\nStandardized data: ");
      ClusterSOM.MatShow(zX, 4, 9, true);

      // step 3: configure the SOM and train it
      int numClusters = 3;
      double maxLearnRate = 0.50;
      int maxSteps = 1000;
      Console.WriteLine(
        "\nSetting num clusters k = " + numClusters);
      Console.WriteLine(
        "Setting  lrnRateMax = " + maxLearnRate.ToString("F2"));
      Console.WriteLine("Setting stepsMax = " + maxSteps);

      Console.WriteLine("\nComputing SOM clustering ");
      ClusterSOM som = new ClusterSOM(zX, numClusters, seed: 0);
      som.Cluster(maxLearnRate, maxSteps);
      Console.WriteLine("Done ");

      // step 4: print each node vector, then each node's item list
      Console.WriteLine("\nSOM map nodes: ");
      for (int node = 0; node < numClusters; ++node)
      {
        Console.Write("k = " + node + ": ");
        ClusterSOM.VecShow(som.map[0][node], 4, 9);
      }

      Console.WriteLine("\nSOM mapping: ");
      for (int node = 0; node < numClusters; ++node)
      {
        Console.Write("k = " + node + ":    ");
        ClusterSOM.ListShow(som.mapping[0][node]);
      }

      double[][] nodeDists = som.GetBetweenNodeDists();
      Console.WriteLine("\nBetween map node distances: ");
      ClusterSOM.MatShow(nodeDists, 2, 6, true);

      // step 5: one cluster id per data item
      Console.WriteLine("\nclustering: ");
      int[] clusterIds = som.GetClustering();
      ClusterSOM.VecShow(clusterIds, wid: 3);

      Console.WriteLine("\nEnd SOM clustering demo");
      Console.ReadLine();
    } // Main

    // ------------------------------------------------------

  } // Program

  public class ClusterSOM
  {
    public int k;  // number clusters
    public double[][] data;  // data to cluster
    public double[][][] map;  // [r][c][vec]
    public List<int>[][] mapping;  // [r][c](List indices)
    public Random rnd;  // to initialize map cells

    // ------------------------------------------------------
    // methods: ctor(), Cluster(), GetClustering(),
    // GetBetweenNodeDists()
    // ------------------------------------------------------

    // Creates a clusterer over X with a 1-by-k map. Every map
    // node is a vector of the same length as X[0], filled with
    // values from Random(seed).NextDouble(). Each map cell also
    // gets an empty list that Cluster() later fills with the
    // indices of the data items assigned to that cell.
    // X is stored by reference (not copied).
    public ClusterSOM(double[][] X, int k, int seed)
    {
      const int mapRows = 1;  // the map is a single row
      int mapCols = k;        // one column per cluster
      this.k = k;
      int vecLen = X[0].Length;

      this.rnd = new Random(seed);
      this.data = X;

      this.map = new double[mapRows][][];        // [r][c][vec]
      this.mapping = new List<int>[mapRows][];   // [r][c](list)
      for (int r = 0; r < mapRows; ++r)
      {
        double[][] nodeRow = new double[mapCols][];
        List<int>[] listRow = new List<int>[mapCols];
        for (int c = 0; c < mapCols; ++c)
        {
          // random initial node vector
          double[] node = new double[vecLen];
          for (int d = 0; d < vecLen; ++d)
            node[d] = this.rnd.NextDouble();
          nodeRow[c] = node;
          listRow[c] = new List<int>();
        }
        this.map[r] = nodeRow;
        this.mapping[r] = listRow;
      }
    } // ctor

    // ------------------------------------------------------

    // Trains the 1-by-k map on this.data, then assigns every data
    // item to its closest map node.
    //
    // For each of stepsMax steps: pick a random data item, find
    // its closest map node (the 'bmu'), and move every node whose
    // Manhattan grid distance from the bmu is within the current
    // range toward that data item. Both the range and the
    // learning rate shrink linearly with the fraction of steps
    // remaining (starting from nRows + nCols and lrnRateMax).
    //
    // About every fifth of the run, the sum over all items of the
    // Euclidean distance to the closest node is printed.
    //
    // On return this.mapping[0][c] holds the indices of the data
    // items whose closest node is c. Lists are rebuilt on every
    // call, so calling Cluster() again does not duplicate entries.
    public void Cluster(double lrnRateMax, int stepsMax)
    {
      int n = this.data.Length;
      int dim = this.data[0].Length;
      int nRows = 1; // for map
      int nCols = this.k; // for map
      int rangeMax = nRows + nCols;

      // progress interval; clamped to >= 1 so that stepsMax < 5
      // does not cause a modulo-by-zero
      int progressEvery = Math.Max(1, stepsMax / 5);

      // compute the map
      for (int step = 0; step < stepsMax; ++step)
      {
        if (step % progressEvery == 0)
        {
          Console.Write("map build step = " +
            step.ToString().PadLeft(4));
          double sum = 0.0;  // sum Euclidean distances
          for (int ix = 0; ix < n; ++ix)
          {
            int[] RC = ClosestNode(ix);
            int kk = RC[1];  // RC[0] is always 0 (single row)
            double[] item = this.data[ix];
            double[] node = this.map[0][kk];
            sum += EucDist(item, node);
          }
          Console.WriteLine("  |  SED = " +
            sum.ToString("F4").PadLeft(9));
        }

        // fraction of training left: 1.0 at step 0, toward 0
        double pctLeft = 1.0 - ((step * 1.0) / stepsMax);
        int currRange = (int)(pctLeft * rangeMax);
        double currLrnRate = pctLeft * lrnRateMax;
        // pick random data index
        int idx = this.rnd.Next(0, n);
        // get (row,col) of closest map node -- 'bmu'
        int[] bmuRC = ClosestNode(idx);
        // move each map node near the bmu closer to the data item
        for (int i = 0; i < nRows; ++i)
        {
          for (int j = 0; j < nCols; ++j)
          {
            if (ManDist(bmuRC[0],
              bmuRC[1], i, j) <= currRange)
            {
              for (int d = 0; d < dim; ++d)
                this.map[i][j][d] = this.map[i][j][d] +
                  currLrnRate * (this.data[idx][d] -
                  this.map[i][j][d]);
            }
          } // j
        } // i
      } // step
      // map has been created

      // discard assignments left over from a previous call
      for (int i = 0; i < nRows; ++i)
        for (int j = 0; j < nCols; ++j)
          this.mapping[i][j].Clear();

      // compute mapping
      for (int idx = 0; idx < n; ++idx)
      {
        // map coords of node closest to data[idx]
        int[] rc = ClosestNode(idx);
        int r = rc[0]; int c = rc[1];
        this.mapping[r][c].Add(idx);
      }

      // results in this.mapping
      return;
    } // Cluster()

    // ------------------------------------------------------

    // Returns one cluster id per data item: element i is the
    // column index of the map node whose mapping list contains i.
    // Items that appear in no list keep the default id 0.
    public int[] GetClustering()
    {
      int[] ids = new int[this.data.Length];
      for (int node = 0; node < this.k; ++node)
      {
        // every index stored under this node gets this node's id
        foreach (int itemIdx in this.mapping[0][node])
          ids[itemIdx] = node;
      }
      return ids;
    }

    // Returns a k-by-k symmetric matrix whose [a][b] entry is the
    // Euclidean distance between map nodes a and b.
    public double[][] GetBetweenNodeDists()
    {
      int count = this.k;

      // allocate all rows first: the fill below writes [b][a] too
      double[][] dists = new double[count][];
      for (int r = 0; r < count; ++r)
        dists[r] = new double[count];

      // visit the upper triangle (diagonal included) and mirror
      for (int a = 0; a < count; ++a)
      {
        for (int b = a; b < count; ++b)
        {
          double d = EucDist(this.map[0][a], this.map[0][b]);
          dists[a][b] = d;
          dists[b][a] = d;
        }
      }
      return dists;
    }

    // ------------------------------------------------------
    // helpers: ManDist(), EucDist(), ClosestNode()
    // ------------------------------------------------------

    // Manhattan (city-block) distance between map cells
    // (r1, c1) and (r2, c2): |r1 - r2| + |c1 - c2|.
    private static int ManDist(int r1, int c1,
      int r2, int c2)
    {
      int rowGap = Math.Abs(r1 - r2);
      int colGap = Math.Abs(c1 - c2);
      return rowGap + colGap;
    }

    // ------------------------------------------------------

    // Euclidean distance between v1 and v2: square root of the
    // sum of squared element differences. Iterates over
    // v1.Length elements, so v2 must be at least that long.
    private static double EucDist(double[] v1,
      double[] v2)
    {
      double sumSq = 0;
      for (int d = 0; d < v1.Length; ++d)
      {
        double diff = v1[d] - v2[d];
        sumSq += diff * diff;
      }
      return Math.Sqrt(sumSq);
    }

    // ------------------------------------------------------

    // Returns { row, col } of the map node with the smallest
    // Euclidean distance to data[idx]. Ties keep the first node
    // found in row-major order; { 0, 0 } is returned if no node
    // is strictly closer than double.MaxValue.
    private int[] ClosestNode(int idx)
    {
      int bestRow = 0;
      int bestCol = 0;
      double bestDist = double.MaxValue;

      for (int r = 0; r < this.map.Length; ++r)
      {
        // column count is read from row 0 for every row
        int nCols = this.map[0].Length;
        for (int c = 0; c < nCols; ++c)
        {
          double d = EucDist(this.data[idx], this.map[r][c]);
          if (d < bestDist)
          {
            bestDist = d;
            bestRow = r;
            bestCol = c;
          }
        }
      }
      return new int[] { bestRow, bestCol };  // (row, col)
    }

    // ------------------------------------------------------

    // misc. public utility functions for convenience
    // MatLoad(), VecLoad(), MatShow(), VecShow(),
    // ListShow(), MatStandard()

    // Loads selected numeric columns of a delimited text file
    // into a matrix (one row per data line).
    //
    // fn      : path of the file to read (opened read-only)
    // usecols : zero-based token indices to keep, in output order
    // sep     : character that separates tokens on a line
    // comment : lines starting with this string are skipped;
    //           null or "" means no line is treated as a comment
    //
    // Blank (empty or whitespace-only) lines are skipped.
    // Numbers are parsed with the invariant culture, so "1.5" is
    // read the same way regardless of the machine's locale.
    // Throws FormatException if a selected token is not a number
    // and IndexOutOfRangeException if a line has too few tokens.
    public static double[][] MatLoad(string fn,
      int[] usecols, char sep, string comment)
    {
      bool hasComment = !string.IsNullOrEmpty(comment);
      int nCols = usecols.Length;
      List<double[]> rows = new List<double[]>();

      // single pass; 'using' closes the file even if parsing throws
      using (StreamReader sr = new StreamReader(fn))
      {
        string line;
        while ((line = sr.ReadLine()) != null)
        {
          if (line.Trim().Length == 0)
            continue;  // blank line
          if (hasComment &&
            line.StartsWith(comment, StringComparison.Ordinal))
            continue;  // comment line

          string[] tokens = line.Split(sep);
          double[] row = new double[nCols];
          for (int j = 0; j < nCols; ++j)
          {
            int k = usecols[j];  // into tokens
            row[j] = double.Parse(tokens[k],
              System.Globalization.CultureInfo.InvariantCulture);
          }
          rows.Add(row);
        }
      }
      return rows.ToArray();
    }

    // ------------------------------------------------------

    // Loads one column of a delimited text file as integers.
    //
    // fn      : path of the file to read
    // usecol  : zero-based index of the token to keep
    // comment : lines starting with this string are skipped
    // sep     : token separator; defaults to ',' which is the
    //           value this method always used before
    //
    // Each value is parsed as a double by MatLoad() and then
    // cast to int, so any fractional part is truncated.
    public static int[] VecLoad(string fn, int usecol,
      string comment, char sep = ',')
    {
      double[][] tmp = MatLoad(fn, new int[] { usecol },
        sep, comment);
      int n = tmp.Length;
      int[] result = new int[n];
      for (int i = 0; i < n; ++i)
        result[i] = (int)tmp[i][0];
      return result;
    }

    // ------------------------------------------------------

    // Prints matrix M to the console, one row per line.
    //
    // M           : matrix to print; rows may differ in length
    // dec         : number of digits after the decimal point
    // wid         : each value is left-padded to this width
    // showIndices : when true, each line starts with "[i]" where
    //               i is the row index, padded to the number of
    //               digits in M.Length
    //
    // Values whose magnitude is below 10^-dec are printed as
    // zero (presumably so tiny negatives do not show as "-0.0").
    public static void MatShow(double[][] M, int dec,
      int wid, bool showIndices)
    {
      double small = 1.0 / Math.Pow(10, dec);
      int pad = M.Length.ToString().Length;  // same for all rows
      for (int i = 0; i < M.Length; ++i)
      {
        if (showIndices == true)
        {
          Console.Write("[" + i.ToString().
            PadLeft(pad) + "]");
        }
        // use this row's own length so ragged rows print fully
        for (int j = 0; j < M[i].Length; ++j)
        {
          double v = M[i][j];
          if (Math.Abs(v) < small) v = 0.0;
          Console.Write(v.ToString("F" + dec).
            PadLeft(wid));
        }
        Console.WriteLine("");
      }
    }

    // ------------------------------------------------------

    // Prints an int vector on one console line, each value
    // left-padded to wid characters, followed by a newline.
    public static void VecShow(int[] vec, int wid)
    {
      foreach (int value in vec)
      {
        string cell = value.ToString().PadLeft(wid);
        Console.Write(cell);
      }
      Console.WriteLine("");
    }

    // ------------------------------------------------------

    // Prints a double vector on one console line. Each value is
    // formatted with 'decimals' digits after the decimal point
    // and left-padded to wid characters; a newline follows.
    public static void VecShow(double[] vec, int decimals,
      int wid)
    {
      foreach (double value in vec)
      {
        string cell = value.ToString("F" + decimals);
        Console.Write(cell.PadLeft(wid));
      }
      Console.WriteLine("");
    }

    // ------------------------------------------------------

    // Prints the list items on one console line, each followed
    // by a single space, then a newline.
    public static void ListShow(List<int> lst)
    {
      foreach (int item in lst)
        Console.Write(item + " ");
      Console.WriteLine("");
    }

    // ------------------------------------------------------

    // Returns a z-score standardized copy of data: each column
    // has its mean subtracted and is divided by its biased
    // (divide-by-n) standard deviation. The input is not modified.
    //
    // A column whose standard deviation is exactly zero (all
    // values equal) is divided by 1 instead, so it becomes all
    // zeros rather than NaN. An empty input returns an empty
    // matrix. Column count is taken from the first row.
    public static double[][] MatStandard(double[][] data)
    {
      int nRows = data.Length;
      if (nRows == 0)
        return new double[0][];  // nothing to standardize
      int nCols = data[0].Length;

      // make result matrix
      double[][] result = new double[nRows][];
      for (int r = 0; r < nRows; ++r)
        result[r] = new double[nCols];

      // compute means
      double[] mns = new double[nCols];
      for (int j = 0; j < nCols; ++j)
      {
        double sum = 0.0;
        for (int i = 0; i < nRows; ++i)
          sum += data[i][j];
        mns[j] = sum / nRows;
      } // j

      // compute std devs
      double[] sds = new double[nCols];
      for (int j = 0; j < nCols; ++j)
      {
        double sum = 0.0;
        for (int i = 0; i < nRows; ++i)
          sum += (data[i][j] - mns[j]) *
            (data[i][j] - mns[j]);
        sds[j] = Math.Sqrt(sum / nRows);  // biased
        // constant column: avoid 0/0 = NaN in the division below
        if (sds[j] == 0.0)
          sds[j] = 1.0;
      } // j

      // normalize
      for (int j = 0; j < nCols; ++j)
      {
        for (int i = 0; i < nRows; ++i)
          result[i][j] =
            (data[i][j] - mns[j]) / sds[j];
      } // j

      return result;
    }

    // ------------------------------------------------------

  } // class SOM

} // ns