///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
#include "basis_symbolic_hermite.h"
using namespace rheolef;
using namespace std;
using namespace GiNaC;

basis_symbolic_hermite_on_geo::size_type 
basis_symbolic_hermite_on_geo::size_family(reference_element::dof_family_type family) const
{
  // Number of basis functions belonging to the requested dof family:
  //  - Lagrange       : one per value node (_node),
  //  - Hermite        : one per derivative node (_node_derivative),
  //  - dof_family_max : size() (presumably the total count -- confirm in header),
  //  - any other value: none.
  if (family == reference_element::Lagrange) {
    return _node.size();
  }
  if (family == reference_element::Hermite) {
    return _node_derivative.size();
  }
  if (family == reference_element::dof_family_max) {
    return size();
  }
  return 0;
}
basis_symbolic_hermite_on_geo::value_type 
basis_symbolic_hermite_on_geo::eval (
    const polynom_type&    p, 
    const basic_point<ex>& xi, 
    size_type              d) const
{
    // Evaluate polynomial p at point xi by replacing, one after the other,
    // the coordinate symbols x, y and z by the matching components of xi.
    // The x substitution is unconditional: even a 0-dimensional element
    // needs the x coordinate (first derivative at points).
    ex value = p;
    value = value.subs (x == ex(xi[0]));
    if (d >= 2) {
        value = value.subs (y == ex(xi[1]));
    }
    if (d >= 3) {
        value = value.subs (z == ex(xi[2]));
    }
    return value;
}
matrix
basis_symbolic_hermite_on_geo::confluent_vandermonde_matrix (
    const vector<ex>&          p,
    size_type                  d) const
{
    // Build the square confluent Vandermonde matrix of the polynomials p:
    //  - row r < m       : p[c] evaluated at value node _node[r],
    //  - row r >= m      : _derivative_sign[k] * d p[c] / d _derivative_variable[k]
    //                      evaluated at _node_derivative[k], with k = r - m.
    // m = number of value nodes; the matrix order is m + number of derivative nodes.
    const unsigned int n_value = _node.size();
    const unsigned int n_total = n_value + _node_derivative.size();
    matrix result (n_total, n_total);

    for (unsigned int row = 0; row < n_total; ++row) {
        if (row < n_value) {
            const basic_point<ex>& point = _node[row];
            for (unsigned int col = 0; col < n_total; ++col) {
                result(row,col) = eval (p[col], point, d);
            }
            continue;
        }
        const unsigned int k = row - n_value;
        const basic_point<ex>& point = _node_derivative[k];
        for (unsigned int col = 0; col < n_total; ++col) {
            result(row,col) = eval (
                (_derivative_sign[k]*(p[col]).diff(_derivative_variable[k])).expand().normal(),
                point, d);
        }
    }
    return result;
}

/* 
 * Build the Hermite basis, associated to nodes.
 *
 * input: node[m], node_derivative[n-m], derivative_variable[n-m] and polynomial_basis[n]
 * output: node_basis[n]
 *   such that node_basis[i](node[k]) = kronecker[i][k],  k=0..m-1
 *             diff(node_basis[i],derivative_variable[k-m])(node_derivative[k-m]) = kronecker[i][k],  k=m..n-1
 *
 * algorithm: let:
 *   b_i = \sum_j a_{i,j} p_j
 *  where:
 *       p_j = polynomial basis [1, x, y, ..]
 *       b_i = basis function associated to the i-th degree of freedom
 *  we want:
 *   b_i(x_k)       = delta_{i,k},    k=0..m-1
 *   b'_i(x'_{k-m}) = delta_{i,k},    k=m..n-1
 *   <=> (with summation over j)
 *   \sum_j a_{i,j} p_j(x_k)       = delta_{i,k},   k=0..m-1
 *   \sum_j a_{i,j} p'_j(x'_{k-m}) = delta_{i,k},   k=m..n-1
 *  where p' denotes derivative_sign[k-m] times the derivative along derivative_variable[k-m].
 * Let A = (a_{i,j})_{i,j} and C = (c_{k,j})_{k,j} with
 *   c_{k,j} = p_j(x_k) if k<m, p'_j(x'_{k-m}) otherwise  (confluent Vandermonde matrix)
 * Then \sum_j a_{i,j} c_{k,j} = delta_{i,k}
 *        <=>  A C^T = I
 *        <=>  A = C^{-T}
 */

void
basis_symbolic_hermite_on_geo::make_node_basis()
{
  assert_macro (_node.size() + _node_derivative.size() == _poly.size(),
	"incompatible node set size (" << _node.size() << "+" << _node_derivative.size()
	<< ") and polynomial basis size (" << _poly.size() << ").");

  const size_type d = _hat_K.dimension();
  const size_type n = _poly.size();
  const size_type m = _node.size();

//#ifdef TO_CLEAN
  warning_macro ("node.size = " << _node.size());
  for (size_type i = 0; i < _node.size(); i++) {
      cerr << "node("<<i<<") = " << _node[i] << endl;
  }
  warning_macro ("poly.size = " << _poly.size());
  for (size_type i = 0; i < _poly.size(); i++) {
      cerr << "poly("<<i<<") = " << _poly[i] << endl;
  }
//#endif // TO_CLEAN
  // Vandermonde matrix vdm(i,j) = pj(xi), pj'(xi)
  matrix vdm = confluent_vandermonde_matrix (_poly, d);
  cerr << vdm << endl;
  ex det_vdm = determinant(vdm);
  check_macro (det_vdm != 0, "basis unisolvence failed on element `" 
		  << _hat_K.name() << "'");
  matrix inv_vdm = vdm.inverse();
  cerr << inv_vdm << endl;
 
  // basis := trans(a)*poly
  _basis.resize(n);
  for (size_type i = 0; i < n; i++) {
    polynom_type s = 0;
    for (size_type j = 0; j < n; j++) {
      s += inv_vdm(j,i) * _poly[j];
    }
    s = expand(s);
    s = normal(s);
    _basis[i] = s;
  }
  warning_macro ("basis.size = " << _basis.size());
  for (size_type i = 0; i < _basis.size(); i++) {
      cerr << "basis("<<i<<") = " << _basis[i] << endl;
  }
  // check:
  matrix vdm_l = confluent_vandermonde_matrix (_basis, d);
  int ndigit10 = Digits;

  numeric tol = ex_to<numeric>(pow(10.,-ndigit10/2.));
  int status = 0;
  for (size_type i = 0; i < n; i++) {
    for (size_type j = 0; j < n; j++) {
      if ((i == j && abs(vdm_l(i,j) - 1) > tol) ||
          (i != j && abs(vdm_l(i,j))     > tol)    ) {
  	  error_macro ("Lagrange polynom check failed.");
      }
    }
  }

  // Family of basis functions: Lagrange for i<m, else Hermite
  _poly_family.resize(n, reference_element::Hermite);
  for (size_type i = 0; i < m; i++) _poly_family[i] = reference_element::Lagrange;

  // derivatives of the basis
  Float d_der=d;
  if (d == 0) d_der=1;  // even if d==0 we need to consider dimension x (first derivative at points) 
  _grad_basis.resize(n);
  for (size_type i = 0; i < n; i++) {
    if (d_der > 0) _grad_basis [i][0] = _basis[i].diff(x).expand().normal();
    if (d_der > 1) _grad_basis [i][1] = _basis[i].diff(y).expand().normal();
    if (d_der > 2) _grad_basis [i][2] = _basis[i].diff(z).expand().normal();
  }
  warning_macro ("grad_basis.size = " << _grad_basis.size());
  for (size_type i = 0; i < _basis.size(); i++) {
      cerr << "grad_basis("<<i<<") = [";
      for (size_type j = 0; j < d_der; j++) {
          cerr << _grad_basis [i][j];
	  if (j == d_der-1) cerr << "]" << endl;
	  else cerr << ", ";
      }
  }

  // Hessian of the basis
  basic_point<GiNaC::ex> ini_g;
  basic_point<basic_point<GiNaC::ex> > ini_h(ini_g,ini_g,ini_g);
  _hessian_basis.resize(n, ini_h);
  for (size_type i = 0; i < n; i++) {
    if (d_der > 0) _hessian_basis [i][0][0] = _basis[i].diff(x).diff(x).expand().normal();
    if (d_der > 1) _hessian_basis [i][0][1] = _basis[i].diff(x).diff(y).expand().normal();
    if (d_der > 2) _hessian_basis [i][0][2] = _basis[i].diff(x).diff(z).expand().normal();
    if (d_der > 1) _hessian_basis [i][1][0] = _hessian_basis [i][0][1];
    if (d_der > 1) _hessian_basis [i][1][1] = _basis[i].diff(y).diff(y).expand().normal();
    if (d_der > 2) _hessian_basis [i][2][0] = _hessian_basis [i][0][2];
    if (d_der > 2) _hessian_basis [i][2][1] = _hessian_basis [i][1][2];
    if (d_der > 2) _hessian_basis [i][2][2] = _basis[i].diff(z).diff(z).expand().normal();
  }
  warning_macro ("hessian_basis.size = " << _hessian_basis.size());
  for (size_type i = 0; i < _basis.size(); i++) {
      for (size_type k = 0; k < d_der; k++) {
      cerr << "hessian_basis("<<i<<")["<<k<<"] = [";
      for (size_type j = 0; j < d_der; j++) {
          cerr << _hessian_basis [i][k][j];
	  if (j == d_der-1) cerr << "]" << endl;
	  else cerr << ", ";
      }
      }
  }

}
