///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
//
#include "rheolef/form.h"
#include "rheolef/form_element.h"
#include "form_assembly.icc"
#include "rheolef/field_expr.h"
#include "rheolef/field_expr_ops.h"


namespace rheolef { 
using namespace std;

template<class T, class M>
form_basic<T,M>::form_basic (
    const space_type& X, 
    const space_type& Y,
    const std::string& name)
  : _X(X), 
    _Y(Y), 
    uu(),
    ub(),
    bu(),
    bb()
{
    // Build the bilinear form "name" from the first space X to the second
    // space Y and assemble its four blocks (uu, ub, bu, bb).
    //
    // Three geometric configurations are handled:
    //  - X and Y both on a domain, or both not on a domain:
    //      form m (V,V,"mass");   e.g. \int_\Omega u v dx
    //      form m (W,W,"mass");   e.g. \int_\Gamma u v ds
    //  - only X on a domain:
    //      form m (W,V,"mass");   e.g. \int_\Gamma u trace(v) ds
    //  - only Y on a domain:
    //      form m (V,W,"mass");   e.g. \int_\Gamma trace(u) v ds
    // with:
    //      geo omega ("square");
    //      geo gamma = omega["boundary"];
    //      V = space(omega,"P1");
    //      W = space(gamma,"P1");
    typedef geo_abstract_base_rep<float_type> geo_rep;
    const bool x_on_domain = (_X.get_geo().data().variant() == geo_rep::geo_domain);
    const bool y_on_domain = (_Y.get_geo().data().variant() == geo_rep::geo_domain);
    const bool same_kind   = (x_on_domain == y_on_domain);

    // Geometry handed to the element-level form: X's own geo when both spaces
    // are of the same kind, otherwise the background geo of X's geo.
    geo_basic<T,M> omega;
    if (same_kind) {
      omega = _X.get_geo();
    } else {
      omega = _X.get_geo().get_background_geo();
    }
    form_element<T,M> elem (name, _X.get_element(), _Y.get_element(), omega);

    if (same_kind) {
      assembly (elem, _X.get_geo(), _Y.get_geo());
      return;
    }
    if (x_on_domain) {
      // X lives on a domain whose background must be Y's geometry.
      check_macro (X.get_geo().get_background_geo() == Y.get_geo(), 
	"form between incompatible geo " << X.get_geo().name() << " and " << Y.get_geo().name());
      const geo_basic<T,M>& y_gamma = _X.get_geo().get_background_domain();
      bool x_geo_is_background = false;
      assembly (elem, _X.get_geo(), y_gamma, x_geo_is_background);
      return;
    }
    // Only Y lives on a domain, whose background must be X's geometry.
    check_macro (Y.get_geo().get_background_geo() == X.get_geo(),
	"form between incompatible geo " << X.get_geo().name() << " and " << Y.get_geo().name());
    const geo_basic<T,M>& x_gamma = _Y.get_geo().get_background_domain();
    bool x_geo_is_background = true;
    assembly (elem, x_gamma, _Y.get_geo(), x_geo_is_background);
}
template<class T, class M>
form_basic<T,M>::form_basic (
    const space_type& X, 
    const space_type& Y,
    const std::string& name,
    const geo_basic<T,M>& gamma)
  : _X(X), 
    _Y(Y), 
    uu(),
    ub(),
    bu(),
    bb()
{
    // Build the bilinear form "name" from X to Y, assembled only over the
    // domain gamma, e.g.
    //    geo omega ("square");
    //    geo gamma = omega["boundary"];
    //    V = space(omega,"P1");
    //    form m (V,V,"mass",gamma);  e.g. \int_\Gamma trace(u) trace(v) ds
    //
    // The element-level form receives the background geo of X's geometry.
    // It is bound to a named const reference so that it stays valid for the
    // whole assembly, whether get_background_geo() returns by value or by
    // reference.
    const geo_basic<T,M>& background = _X.get_geo().get_background_geo();
    form_element<T,M> elem (name, _X.get_element(), _Y.get_element(), background);
    assembly (elem, gamma, gamma);
}
// ----------------------------------------------------------------------------
// blas2
// ----------------------------------------------------------------------------
template<class T, class M>
field_basic<T,M>
form_basic<T,M>::operator* (const field_basic<T,M>& xh) const
{
    // Matrix-vector product y = A*x, computed block by block on the "u" and
    // "b" parts of the fields (presumably unknown and blocked dofs):
    //   y.b = A_bu*x.u + A_bb*x.b
    //   y.u = A_uu*x.u + A_ub*x.b
    // The result lives in the second space Y.
    // TODO: check that xh lives in the first space X and that all vector
    // sizes match; if the boundary conditions differ, one could iterate on
    // the form itself instead (more involved).
    field_basic<T,M> result (_Y, T(0));
    result.b = bu*xh.u + bb*xh.b;
    result.u = uu*xh.u + ub*xh.b;
    return result;
}
template<class T, class M>
typename form_basic<T,M>::float_type
form_basic<T,M>::operator() (const field_basic<T,M>& uh, const field_basic<T,M>& vh) const
{
    // Evaluate the bilinear form: a(uh,vh) = dot (A*uh, vh).
    const field_basic<T,M> a_uh = this->operator* (uh);
    return dot (a_uh, vh);
}
// ----------------------------------------------------------------------------
// output: print the four csr blocks as one large sparse matrix in Matrix Market format
// ----------------------------------------------------------------------------

// Identity index map: returns its argument unchanged.
// Passed to merge() as the row/column renumbering when no permutation is
// needed. The call operator is const so the functor can be used through a
// const object or reference.
struct id {
  size_t operator() (size_t i) const { return i; }
};
// Copy every entry of the csr matrix m into the matrix a.
// The global row index of an entry is mapped through dis_im2dis_idof and its
// global column index through dis_jm2dis_jdof. Entries are written with '=',
// so a later write to the same (i,j) replaces an earlier one.
// When MPI is enabled, the external part of m (ext_begin(); presumably the
// columns not owned locally) is copied as well.
template<class T, class M, class Permutation1, class Permutation2>
static
void
merge (
    asr<T,M>& a, 
    const csr<T,M>& m,
    Permutation1 dis_im2dis_idof,
    Permutation2 dis_jm2dis_jdof)
{
    typedef typename form_basic<T,M>::size_type size_type;
    // local row/column indices are shifted by the first global index owned here
    size_type i0 = m.row_ownership().first_index();
    size_type j0 = m.col_ownership().first_index();
    typename csr<T,M>::const_iterator ia = m.begin(); 
    for (size_type im = 0, nrow = m.nrow(); im < nrow; im++) {
      size_type dis_im = im + i0;
      size_type dis_idof = dis_im2dis_idof (dis_im);
      for (typename csr<T,M>::const_data_iterator p = ia[im]; p != ia[im+1]; p++) {
	const size_type& jm  = (*p).first;
	const T&         val = (*p).second;
	size_type dis_jm     = jm + j0;
	size_type dis_jdof   = dis_jm2dis_jdof (dis_jm);
        a.dis_entry (dis_idof, dis_jdof) = val;
      }
    }
#ifdef _RHEOLEF_HAVE_MPI
    // external entries carry an external column index, converted back to a
    // global column index by jext2dis_j
    typename csr<T,M>::const_iterator ext_ia = m.ext_begin(); 
    for (size_type im = 0, nrow = m.nrow(); im < nrow; im++) {
      size_type dis_im = im + i0;
      size_type dis_idof = dis_im2dis_idof (dis_im);
      for (typename csr<T,M>::const_data_iterator p = ext_ia[im]; p != ext_ia[im+1]; p++) {
	const size_type& jext = (*p).first;
	const T&         val  = (*p).second;
	size_type dis_jm      = m.jext2dis_j (jext);
	size_type dis_jdof    = dis_jm2dis_jdof (dis_jm);
        a.dis_entry (dis_idof, dis_jdof) = val;
      }
    }
#endif // _RHEOLEF_HAVE_MPI
}
template<class T, class M>
odiststream& 
form_basic<T,M>::put (odiststream& ops, bool show_partition) const
{
    // Gather the four blocks into one asr matrix whose rows and columns are
    // all owned by the i/o process. Then print it in Matrix Market coordinate
    // format: the header line, the "nrow ncol nnz" line, then the entries.
    // Only show_partition == true is implemented; otherwise error_macro
    // ("not yet") is raised.
    //
    // NOTE(review): all four blocks are merged with the identity maps id(),
    // so their (row, column) index ranges appear to start at the same origin
    // and overlap. Since merge assigns entries with '=', later blocks would
    // overwrite earlier ones. Confirm whether the b-part should be shifted by
    // the size of the u-part.
    const size_type nrow_all = get_second_space().dis_size();
    const size_type ncol_all =  get_first_space().dis_size();
    const size_type io_proc  = odiststream::io_proc();
    const size_type my_proc  = comm().rank();
    const bool      i_am_io  = (my_proc == io_proc);
    const size_type my_nrow  = i_am_io ? nrow_all : 0;
    const size_type my_ncol  = i_am_io ? ncol_all : 0;
    distributor row_owner (nrow_all, comm(), my_nrow);
    distributor col_owner (ncol_all, comm(), my_ncol);
    asr<T,M> gathered (row_owner, col_owner);

    if (! show_partition) {
        error_macro ("not yet");
    } else {
        merge (gathered, uu, id(), id());
        merge (gathered, ub, id(), id());
        merge (gathered, bu, id(), id());
        merge (gathered, bb, id(), id());
    }
    gathered.dis_entry_assembly();
    ops << "%%MatrixMarket matrix coordinate real general" << std::endl
        << nrow_all << " " << ncol_all << " " << gathered.dis_nnz() << std::endl
        << gathered;
    return ops;
}
template <class T, class M>
void
form_basic<T,M>::dump (std::string name) const
{
    // Dump each of the four blocks to its own target, named after the given
    // base name followed by the block tag: name-uu, name-ub, name-bu, name-bb.
    const std::string uu_name = name + "-uu";
    const std::string ub_name = name + "-ub";
    const std::string bu_name = name + "-bu";
    const std::string bb_name = name + "-bb";
    uu.dump (uu_name);
    ub.dump (ub_name);
    bu.dump (bu_name);
    bb.dump (bb_name);
}
// ----------------------------------------------------------------------------
// explicit instantiation in the library
// ----------------------------------------------------------------------------
#ifndef _RHEOLEF_HAVE_MPI
// Sequential build: instantiate the sequential form.
// TODO: also instantiate the sequential form in the distributed build;
// there are still problems with it.
template class form_basic<Float,sequential>;
#else // _RHEOLEF_HAVE_MPI
// Distributed build: instantiate the distributed form only.
template class form_basic<Float,distributed>;
#endif // _RHEOLEF_HAVE_MPI

}// namespace rheolef
