/* Vgrid - Virtual grid program for radiology
   Copyright (C) 2020, 2021 Sonia Diaz Pacheco.

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#define _FILE_OFFSET_BITS 64

#include <algorithm>
#include <cctype>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

#include "vgrid.h"
#include "matrix.h"


// Check that 'm' has the same height and width as this Matrix.
// On mismatch, print both sizes to stderr and throw Error.
void Matrix::compare_size( const Matrix & m ) const
  {
  const bool same_size = ( height() == m.height() && width() == m.width() );
  if( same_size ) return;
  std::fprintf( stderr, "h1 = %ld, w1 = %ld, h2 = %ld, w2 = %ld\n",
                height(), width(), m.height(), m.width() );
  throw Error( "Can't operate. Matrix sizes differ." );
  }


void Matrix::test_size( const long rows, const long cols )
  {
  if( cols < 1 || rows < 1 )
    throw Error( "Matrix too small. Minimum size is 1x1." );
  if( LONG_MAX / cols < rows )
    throw Error( "Matrix too big. Number of elements will overflow 'long'." );
  }


// Create a rows x cols Matrix with every element set to val.
// Throws Error (via test_size) if the requested size is invalid.
Matrix::Matrix( const long rows, const long cols, const double val )
  {
  test_size( rows, cols );
  data.assign( rows, std::vector< double >( cols, val ) );
  }


// Create a Matrix from a two-dimensional vector (a vector of rows).
// Throws Error if 'd' has no rows or its first row is empty, if the number
// of elements would overflow 'long', or if the rows differ in length.
Matrix::Matrix( const std::vector< std::vector< double > > & d ) : data( d )
  {
  // Reject an empty 'd' before calling width(), which presumably reads
  // data[0] (NOTE(review): confirm against matrix.h). Same message that
  // test_size gives for rows < 1.
  if( data.empty() ) throw Error( "Matrix too small. Minimum size is 1x1." );
  test_size( height(), width() );
  for( long row = 1; row < height(); ++row )	// row 0 is the reference
    if( data[row].size() != data[0].size() )
      throw Error( "Invalid matrix. Row sizes differ." );
  }

// Create a rows x cols Matrix from 'array', which holds rows * cols elements
// in row-major order (element (row, col) is array[row*cols+col]).
// Throws Error (via test_size) if the requested size is invalid.
Matrix::Matrix( const long rows, const long cols, const double array[] )
  {
  test_size( rows, cols );
  data.resize( rows );
  const double * src = array;			// start of the current row
  for( long row = 0; row < rows; ++row, src += cols )
    data[row].assign( src, src + cols );
  }


// Change the size of this Matrix to rows x cols. Elements that still fit
// are kept; newly created elements are set to val. Returns *this.
// Throws Error (via test_size) if the requested size is invalid.
Matrix & Matrix::resize( const long rows, const long cols, const double val )
  {
  test_size( rows, cols );
  data.resize( rows );
  for( long row = 0; row < rows; ++row ) data[row].resize( cols, val );
  return *this;
  }


/* Reduce image and color_info to 1/scale of original size per side.
   Discards up to floor(scale/2) pixels at each image border to make image
   size a multiple of scale.
   Keeps white letters separate from rest of image: a source pixel is white
   if its rounded value is not below the rounded maximum of the pixels used.
   Each output pixel is set to white if MORE than 1/4 of its scale x scale
   source pixels are white, else to the mean of the non-white ones.
   color_info is reduced in place (each output element is the rounded mean
   of its scale x scale block) only if it has 3 planes, each with height()
   rows and a first row of width() elements; otherwise it is left unchanged.
   If scale <= 1, if scale * scale would overflow 'int', or if the image is
   smaller than scale in either dimension, returns a copy of *this and
   leaves color_info unchanged.
*/
Matrix Matrix::reduce( const int scale, Color_info & color_info ) const
  {
  if( scale <= 1 || INT_MAX / scale < scale ||
      height() < scale || width() < scale ) return *this;
  const long rows = height() / scale, cols = width() / scale;	// new sizes
  const int roff = ( height() - rows * scale ) / 2;	// row offset
  const int coff = ( width() - cols * scale ) / 2;	// column offset
  double vmax = -INFINITY;			// max of pixels actually used
  for( long row = roff; row < rows * scale + roff; ++row )
    for( long col = coff; col < cols * scale + coff; ++col )
      if( vmax < data[row][col] ) vmax = data[row][col];
  vmax = nearbyint( vmax );			// value used for white pixels
  const int s2 = scale * scale;		// number of pixels in each average
  std::vector< std::vector< double > > tmp( rows );
  for( long row = 0; row < rows; ++row )
    {
    tmp[row].reserve( cols );
    const long r_ini = row * scale + roff;
    for( long col = 0; col < cols; ++col )
      {
      double mean = 0;
      const long c_ini = col * scale + coff;
      int count = 0;				// count of non-white pixels
      for( long r = r_ini; r < r_ini + scale; ++r )
        for( long c = c_ini; c < c_ini + scale; ++c )
          {
          const double val = data[r][c];
          if( nearbyint( val ) < vmax ) { mean += val; ++count; }
          }
      // White if more than 1/4 of the pixels are white. Exactly equivalent
      // to 'count * 4 < s2 * 3', but cannot overflow 'int' for large scales.
      if( s2 - count > s2 / 4 ) tmp[row].push_back( vmax );	// set pixel white
      else tmp[row].push_back( mean / count );	// count > 0 here
      }
    }

  // Validate every plane before indexing it; the original check covered
  // only plane 0.
  bool color_ok = ( color_info.data.size() == 3 );
  for( int plane = 0; color_ok && plane < 3; ++plane )
    color_ok = (long)color_info.data[plane].size() == height() &&
               (long)color_info.data[plane][0].size() == width();
  if( color_ok )
    for( int plane = 0; plane < 3; ++plane )
      {
      // In-place reduction is safe: the destination (row, col) never lies
      // in a source block not yet read, because row <= r_ini, col <= c_ini,
      // and later blocks start at strictly larger rows/columns.
      for( long row = 0; row < rows; ++row )
        {
        const long r_ini = row * scale + roff;
        for( long col = 0; col < cols; ++col )
          {
          double mean = 0;
          const long c_ini = col * scale + coff;
          for( long r = r_ini; r < r_ini + scale; ++r )
            for( long c = c_ini; c < c_ini + scale; ++c )
              mean += color_info.data[plane][r][c];
          color_info.data[plane][row][col] = (unsigned)nearbyint( mean / s2 );
          }
        color_info.data[plane][row].resize( cols );
        }
      color_info.data[plane].resize( rows );
      }
  return Matrix( tmp );
  }


double Matrix::max() const
  {
  double vmax = -INFINITY;
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      if( vmax < data[row][col] ) vmax = data[row][col];
  return vmax;
  }


double Matrix::min() const
  {
  double vmin = INFINITY;
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      if( vmin > data[row][col] ) vmin = data[row][col];
  return vmin;
  }


double Matrix::mean() const
  {
  double sum = 0;
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      sum += data[row][col];
  return sum / size();
  }


double Matrix::norm1() const
  {
  double res = 0;
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      res += std::abs( data[row][col] );
  return res;
  }


double Matrix::eunorm() const
  {
  double res = 0;
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      res += data[row][col] * data[row][col];
  return std::sqrt( res );
  }


// Return the median of the pixels at most 'radius' rows and columns away
// from (row, col). The window is clipped at the image borders. For an even
// number of pixels, the lower of the two middle values is returned.
double Matrix::median( const long row, const long col, const int radius ) const
  {
  const long top = std::max( 0L, row - radius );
  const long bottom = std::min( row + radius, height() - 1 );
  const long left = std::max( 0L, col - radius );
  const long right = std::min( col + radius, width() - 1 );
  std::vector< double > window;
  if( left <= right )
    for( long r = top; r <= bottom; ++r )	// copy the clipped row segment
      window.insert( window.end(), data[r].begin() + left,
                     data[r].begin() + ( right + 1 ) );
  const std::vector< double >::iterator middle =
    window.begin() + ( ( window.size() - 1 ) / 2 );
  std::nth_element( window.begin(), middle, window.end() );
  return *middle;
  }


// Add 'm' element-wise to this Matrix and return *this.
// Throws Error (via compare_size) if the sizes differ.
Matrix & Matrix::operator+=( const Matrix & m )
  {
  compare_size( m );
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & dst = data[row];
    const std::vector< double > & src = m.data[row];
    for( long col = 0; col < cols; ++col ) dst[col] += src[col];
    }
  return *this;
  }

// Add 'val' to every element and return *this.
Matrix & Matrix::operator+=( const double val )
  {
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & line = data[row];
    for( long col = 0; col < cols; ++col ) line[col] += val;
    }
  return *this;
  }

// Subtract 'm' element-wise from this Matrix and return *this.
// Throws Error (via compare_size) if the sizes differ.
Matrix & Matrix::operator-=( const Matrix & m )
  {
  compare_size( m );
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & dst = data[row];
    const std::vector< double > & src = m.data[row];
    for( long col = 0; col < cols; ++col ) dst[col] -= src[col];
    }
  return *this;
  }

// Subtract 'val' from every element and return *this.
Matrix & Matrix::operator-=( const double val )
  {
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & line = data[row];
    for( long col = 0; col < cols; ++col ) line[col] -= val;
    }
  return *this;
  }

// Multiply every element by 'val' and return *this.
Matrix & Matrix::operator*=( const double val )
  {
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & line = data[row];
    for( long col = 0; col < cols; ++col ) line[col] *= val;
    }
  return *this;
  }

// Divide every element by 'val' and return *this. No check is made for
// val == 0; the result follows floating-point division rules.
Matrix & Matrix::operator/=( const double val )
  {
  const long rows = height(), cols = width();
  for( long row = 0; row < rows; ++row )
    {
    std::vector< double > & line = data[row];
    for( long col = 0; col < cols; ++col ) line[col] /= val;
    }
  return *this;
  }


bool Matrix::equivalent( const Matrix & m, const double epsilon ) const
  {
  compare_size( m );
  for( long row = 0; row < height(); ++row )
    for( long col = 0; col < width(); ++col )
      if( std::abs( data[row][col] - m.data[row][col] ) > epsilon )
        return false;
  return true;
  }


// Replace one element equal to 'old_val' in 'neighborhood' by 'new_val'.
// The caller guarantees that 'old_val' is present in 'neighborhood'.
// (NaN values are not supported: std::find can't match them.)
static void replace_neighbor( std::vector< double > & neighborhood,
                              const double old_val, const double new_val )
  {
  const std::vector< double >::iterator i =
    std::find( neighborhood.begin(), neighborhood.end(), old_val );
  *i = new_val;				// replace old with new pixel
  }


/* Return a copy of the image with every pixel replaced by the median of the
   pixels at most 'radius' rows and columns away from it.
   Limit median to pixels in image. (No border expansion).
   Border pixels, whose window would extend past the image, are computed
   one by one with 'median', which clips the window. Inner pixels are
   computed with a sliding window that sweeps the rows alternately
   left-to-right and right-to-left. At each step only the row or column of
   pixels entering the window replaces the one leaving it.
   If radius <= 0, returns an unmodified copy.
*/
Matrix Matrix::median_filter( const int radius ) const
  {
  // A 1x1 window leaves every pixel unchanged. A negative radius would make
  // 'median' work on an empty neighborhood.
  if( radius <= 0 ) return *this;
  const long inner_top = std::min( (long)radius, height() / 2 );
  const long inner_bottom = std::max( height() - 1 - radius, height() / 2 - 1 );
  const long inner_left = std::min( (long)radius, width() / 2 );
  const long inner_right = std::max( width() - 1 - radius, width() / 2 - 1 );
  Matrix tmp( height(), width() );	// result; every element is set below
  for( long col = 0; col < width(); ++col )
    {
    for( long row = 0; row < inner_top; ++row )
      tmp.data[row][col] = median( row, col, radius );	// filter top border
    for( long row = inner_bottom + 1; row < height(); ++row )
      tmp.data[row][col] = median( row, col, radius );	// filter bottom border
    }
  for( long row = 0; row < height(); ++row )
    {
    for( long col = 0; col < inner_left; ++col )
      tmp.data[row][col] = median( row, col, radius );	// filter left border
    for( long col = inner_right + 1; col < width(); ++col )
      tmp.data[row][col] = median( row, col, radius );	// filter right border
    }

  // If the image is not taller (or not wider) than 2 * radius there are no
  // inner pixels. The border loops above then cover every pixel, and
  // building the initial window below would read outside the image.
  if( inner_top > inner_bottom || inner_left > inner_right ) return tmp;

  std::vector< double > neighborhood;	// window around the current pixel
  neighborhood.reserve( ( 2L * radius + 1 ) * ( 2L * radius + 1 ) );
  for( long row = inner_top - radius; row <= inner_top + radius; ++row )
    for( long col = inner_left - radius; col <= inner_left + radius; ++col )
      neighborhood.push_back( data[row][col] );		// fill neighborhood

  // Pointer to central element (median). It stays valid because the
  // neighborhood is only permuted and overwritten, never resized.
  const std::vector< double >::iterator median_ptr =
    neighborhood.begin() + ( ( neighborhood.size() - 1 ) / 2 );
  for( long row = inner_top; row <= inner_bottom; ++row )
    {
    // sweep left to right
    for( long col = inner_left; col <= inner_right; ++col )
      {
      std::nth_element( neighborhood.begin(), median_ptr, neighborhood.end() );
      tmp.data[row][col] = *median_ptr;		// set pixel to median
      if( col < inner_right )		// update neighborhood to the right
        for( long r = row - radius; r <= row + radius; ++r )
          replace_neighbor( neighborhood, data[r][col - radius],
                            data[r][col + radius + 1] );
      }
    if( row >= inner_bottom ) break;
    // update neighborhood downwards (window is centered at inner_right)
    for( long c = inner_right - radius; c <= inner_right + radius; ++c )
      replace_neighbor( neighborhood, data[row - radius][c],
                        data[row + radius + 1][c] );
    ++row;
    // sweep right to left
    for( long col = inner_right; col >= inner_left; --col )
      {
      std::nth_element( neighborhood.begin(), median_ptr, neighborhood.end() );
      tmp.data[row][col] = *median_ptr;		// set pixel to median
      if( col > inner_left )		// update neighborhood to the left
        for( long r = row - radius; r <= row + radius; ++r )
          replace_neighbor( neighborhood, data[r][col + radius],
                            data[r][col - radius - 1] );
      }
    if( row >= inner_bottom ) break;
    // update neighborhood downwards (window is centered at inner_left)
    for( long c = inner_left - radius; c <= inner_left + radius; ++c )
      replace_neighbor( neighborhood, data[row - radius][c],
                        data[row + radius + 1][c] );
    }
  return tmp;
  }


void Matrix::print() const
  {
  for( long row = 0; row < height(); ++row )
    {
    for( long col = 0; col < width(); ++col )
      printf( " %3g", data[row][col] );		// use %3.36f for full value
    std::fputc( '\n', stdout );
    }
  }
