// -*- C++ -*-

// Copyright 2006-2007 Deutsches Forschungszentrum fuer Kuenstliche Intelligenz
// or its licensors, as applicable.
//
// You may not use this file except under the terms of the accompanying license.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Project: ocr-bpnet - neural network classifier
// File: classify-chars.cc
// Purpose: train/recognize characters with feature extraction and classmap information
// Responsible: kapry
// Reviewer: rangoni
// Primary Repository:
// Web Sites: www.iupr.org, www.dfki.de

#include "colib.h"
#include "classmap.h"
#include "charlib.h"
#include "classify-chars.h"
#include "feature-extractor.h"
#include "narray-io.h"
#include "additions.h"
#include "confusion-matrix.h"
#include "bpnet.h"
#include "didegrade.h"

using namespace ocropus;
using namespace colib;
using namespace iupr_bpnet;

namespace {

    // log the final confusion matrix for the training dataset
    Logger logger_confusion_map_train("conf_map_train");
    // log the final confusion matrix for the testing dataset
    Logger logger_confusion_map_test("conf_map_test");
    // log the final confusion matrix in a list-layout for the training dataset
    Logger logger_conf_map_train_reduced("conf_map_train_red");
    // log the final confusion matrix in a list-layout for the testing dataset
    Logger logger_conf_map_test_reduced("conf_map_test_red");

    int nb_max_features = 9;
    FeatureExtractor::FeatureType possible_features[] = {
        FeatureExtractor::BAYS,
        FeatureExtractor::GRAD,
        FeatureExtractor::INCL,
        FeatureExtractor::IMAGE,
        FeatureExtractor::SKEL,
        FeatureExtractor::SKELPTS,
        FeatureExtractor::RELSIZE,
        FeatureExtractor::SKEL2,
        FeatureExtractor::HISTO
    };
    char possible_feature_names[][10] = {"BAYS","GRAD","INCL","IMAGE","SKEL",
                            "SKELPTS","RELSIZE","SKEL2","HISTO"};
    char OUR_PREVIOUS_FEATURES[] = "111111100";
    char OUR_FEATURES[] = "111101110";
}

struct LineCharacterClassifier : ICharacterClassifier {
    autodel<IClassifier>        classifier;
    autodel<FeatureExtractor>   extractor;
    ClassMap map;

    objlist<nustring>   variants;
    floatarray          costs;

    bool        output_garbage;
    int         ninput;
    bool        init;

    intarray    whichfeatures;
    int         dim_f_x, dim_f_y;
    bool        usedegrade;             // if set to true, the characters and garbages are degraded on the fly when feeding the classifier
    float       garbage_portion;        // 0.0 means no garbage, 1.0 means all garbage

    void append_our_features(FeatureExtractor &extractor,
                             floatarray &features,
                             bytearray &image,
                             bool with_line_info = true) {
        for(int i=0;i<nb_max_features;i++) {
            if(whichfeatures[i] == 1) {
                extractor.appendFeatures(features, image, possible_features[i]);
            }
        }
        if(with_line_info)
            extractor.appendFeatures(features, image, FeatureExtractor::POS);
    }

    virtual void setImage(bytearray &in) {
        variants.clear();
        costs.clear();
        floatarray features;
        append_our_features(*extractor, features, in, false);

        floatarray result;
        classifier->score(result, features);

        for(int i=0; i<result.length();i++) {
            if(output_garbage || map.get_ascii(i) != GARBAGE) {  // if not garbage; is this a kluge?
                variants.push().push(nuchar(map.get_ascii(i)));
                costs.push(-log(result[i]));
            }
        }
    }

    virtual void setImage(bytearray &in,
                          int baseline,
                          int xheight_y,
                          int descender_y,
                          int ascender_y) {
        variants.clear();
        costs.clear();
        floatarray features;
        extractor->setLineInfo(baseline,xheight_y-baseline);
        append_our_features(*extractor, features, in);
        floatarray result;
        classifier->score(result, features);

        for(int i=0;i<result.length();i++) {
            if(output_garbage || map.get_ascii(i) != GARBAGE) {  // if not garbage; is this a kluge?
                variants.push().push(nuchar(map.get_ascii(i)));
                costs.push(-log(result[i]));
            }
        }
    }

    void outputfeature(bytearray &in, const char* where) {
        floatarray features;
        append_our_features(*extractor, features, in, false);
        bytearray png;
        extractor->feat2image(png, features);
        additions::save_char(png,where);
    }

    virtual void cls(nustring &result, int i) {
        copy(result, variants[i]);
    }

    virtual float cost(int i) {
        return costs[i];
    }

    virtual int length() {
        ASSERT(variants.length() == costs.length());
        return variants.length();
    }

    virtual void load(FILE *stream) {
        // the load method does not use the read_check_point since it is
        // necessary to handle previously trained and saved classifiers which
        // do not include 'feature' information
        map.load(stream);
        classifier->load(stream);

        init = true;
        char aux[1024];aux[0]='\0';
        fscanf(stream,"%s",aux);
        if(strcmp(aux,"dim_features") == 0) {  // if feature size information
            fscanf(stream,"%d",&dim_f_x);
            fscanf(stream,"%d",&dim_f_y);
            fscanf(stream,"%s",aux);
            if(strcmp(aux,"features") == 0) {  // if feature name information
                int tmp;
                fscanf(stream,"%d",&tmp);
                ASSERT(tmp<=nb_max_features);
                for(int i=0; i<nb_max_features;++i) {
                    if(fscanf(stream,"%d %s",&tmp,aux) == 2) {
                        whichfeatures[i] = tmp;
                    } else {
                        throw "features are missing";
                    }
                }
            }
        } else {                                    // it is probably an old saved classifier
            dim_f_x = 10;
            dim_f_y = 10;
            whichfeatures.resize(nb_max_features);
            fill(whichfeatures,0);
            for(int i=0;i<strlen(OUR_PREVIOUS_FEATURES);i++) {
                whichfeatures[i] = (OUR_PREVIOUS_FEATURES[i]=='1')?1:0;
            }
            printf("WARNING: bpnet file is in an old format, supported only for backward compatibility\n");
        }
        (*extractor).setFeatureSize(dim_f_x, dim_f_y);
    }

    virtual void save(FILE *stream) {
        map.save(stream);
        classifier->save(stream);
        write_checkpoint(stream,"dim_features");
        fprintf(stream, "%d\n", dim_f_x);
        fprintf(stream, "%d\n", dim_f_y);
        write_checkpoint(stream,"features");

        fprintf(stream, "%d\n", nb_max_features);
        for(int i=0;i<nb_max_features;++i) {
            fprintf(stream,"%d %s\n",whichfeatures[i],possible_feature_names[i]);
        }
    }

    virtual const char *description() {
        return "character classifier";
    }

    void printFeatures() {
        printf("{");
        for(int i=0;i<whichfeatures.dim(0);i++) {
            if(whichfeatures[i]) {
                printf(" %s", possible_feature_names[i]);
            }
        }
        printf("}\n");
        printf("{");
        for(int i=0;i<whichfeatures.dim(0);i++) {
            printf(" %d", whichfeatures[i]);
        }
        printf(" }\n");
        printf("Size : %d %d\n", dim_f_x, dim_f_y);
    }

    LineCharacterClassifier(IClassifier *c,bool garbage,const char* strfeatures,
                            int dim_x,int dim_y):
        classifier(c), extractor(make_FeatureExtractor()),
        output_garbage(garbage),
        ninput(-1),
        init(false),
        usedegrade(false),
        garbage_portion(1.)
            {
            whichfeatures.resize(strlen(strfeatures));
            fill(whichfeatures,0);
            for(int i=0;i<strlen(strfeatures);i++) {
                whichfeatures[i] = (strfeatures[i]=='1')?1:0;
            }
            dim_f_x = dim_x;
            dim_f_y = dim_y;
            (*extractor).setFeatureSize(dim_x,dim_y);
    }

    LineCharacterClassifier(const char* strfeatures, int dim_x, int dim_y):
        classifier(make_BpnetClassifier()), extractor(make_FeatureExtractor()),
        output_garbage(false),
        ninput(-1),
        init(false) {
            whichfeatures.resize(strlen(strfeatures));
            fill(whichfeatures,0);
            for(int i=0;i<strlen(strfeatures);i++) {
                whichfeatures[i] = (strfeatures[i]=='1')?1:0;
            }
            dim_f_x = dim_x;
            dim_f_y = dim_y;
            (*extractor).setFeatureSize(dim_x,dim_y);
    }

    virtual void addTrainingChar(bytearray &image,int base_y,int xheight_y,
                                        int descender_y,int ascender_y,nustring &characters) {
        if(characters.length() != 1) {
            throw "addTrainingChar cannot handle multiple characters";
        }
        floatarray features;
        extractor->setLineInfo(base_y, xheight_y - base_y);
        if(usedegrade) {
            degrade(image);
        }
        append_our_features(*extractor, features, image);
        ninput = features.length();
        int cls;
        ASSERTWARN(characters[0].ord()!=32);    // maybe it's an error to put a white space

        if(init) {
            cls = map.get_class_no_add(characters[0].ord());
            if(cls!=-1) {
                if(characters[0].ord() != GARBAGE) {
                    classifier->add(features, cls);
                } else {
                    if((rand()/float(RAND_MAX))<garbage_portion) {
                        classifier->add(features, cls);
                    }
                }
            }
        } else {
            cls = map.get_class(characters[0].ord());
            if(characters[0].ord() != GARBAGE) {
                classifier->add(features, cls);
            } else {
                if((rand()/float(RAND_MAX))<garbage_portion) {
                    classifier->add(features, cls);
                }
            }
        }
    }

    virtual void addTrainingChar(bytearray &image, nustring &characters) {
        if(characters.length() != 1) {
            throw "addTrainingChar cannot handle multiple characters";
        }
        floatarray features;
        append_our_features(*extractor, features, image, false);
        ninput = features.length();

        int cls;
        if(init) {
            cls = map.get_class_no_add(characters[0].ord());
            if(cls!=-1) {
                if(characters[0].ord() != GARBAGE) {
                    classifier->add(features, cls);
                } else {
                    if((rand()/float(RAND_MAX))<garbage_portion) {
                        classifier->add(features, cls);
                    }
                }
            }
        } else {
            cls = map.get_class(characters[0].ord());
            if(characters[0].ord() != GARBAGE) {
                classifier->add(features, cls);
            } else {
                if((rand()/float(RAND_MAX))<garbage_portion) {
                    classifier->add(features, cls);
                }
            }
        }
    }

    virtual void startTraining(const char *type) {
        classifier->start_training();
    }

    virtual void finishTraining() {
        if(!init) {
            classifier->param("ninput", ninput);
            classifier->param("noutput", map.length());
        }
        //classifier->start_classifying();
        intarray conf_matrix_train;
        intarray conf_matrix_test;
        classifier->start_classifying(conf_matrix_train,conf_matrix_test);
        autodel<ConfusionMatrix> CMtrain;
        autodel<ConfusionMatrix> CMtest;

        CMtrain = make_ConfusionMatrix(conf_matrix_train);
        CMtest = make_ConfusionMatrix(conf_matrix_test);

        logger_confusion_map_train.format("Final train confusion matrix");
        logger_confusion_map_train.confusion(*CMtrain,map);
        logger_confusion_map_test.format("Final test confusion matrix");
        logger_confusion_map_test.confusion(*CMtest,map);

        logger_conf_map_train_reduced.format("Final train reduced confusion matrix");
        logger_conf_map_train_reduced.reduced_confusion(*CMtrain,map);
        logger_conf_map_test_reduced.format("Final test reduced confusion matrix");
        logger_conf_map_test_reduced.reduced_confusion(*CMtest,map);
        init = true;
    }

    virtual void set(const char *key, double value) {
        if(strcmp(key,"garbage_portion") == 0) {
            garbage_portion = value;
        } else if(strcmp(key,"degrade") == 0) {
            usedegrade = bool(value);
        } else {
            classifier->param(key, value);
        }
    }
};

namespace ocropus {

    // Feed every character of every section of `charlib` to `classifier` as a
    // training sample (with line information) and report each finished
    // section on stdout.
    void train(ICharacterClassifier &classifier, ICharacterLibrary &charlib) {
        int section = 0;
        while(section < charlib.sectionsCount()) {
            charlib.switchToSection(section);
            nustring label(1);
            int index = 0;
            while(index < charlib.charactersCount()) {
                label[0] = nuchar(charlib.character(index).code());
                // the x-height is handed over as baseline + x-height
                classifier.addTrainingChar(
                    charlib.character(index).image(),
                    charlib.character(index).baseline(),
                    charlib.character(index).xHeight() + charlib.character(index).baseline(),
                    charlib.character(index).descender(),
                    charlib.character(index).ascender(),
                    label);
                ++index;
            }
            printf("done with section %d\n", section);
            ++section;
        }
    }

    // Create a character classifier around the classifier `c`.
    // `strfeatures` must contain exactly one flag per possible feature.
    ICharacterClassifier *make_AdaptClassifier(IClassifier *c,
                                               bool garbage,
                                               const char* strfeatures,
                                               int dim_x,
                                               int dim_y) {
        ASSERT(strlen(strfeatures)==nb_max_features);
        ICharacterClassifier *made =
            new LineCharacterClassifier(c, garbage, strfeatures, dim_x, dim_y);
        return made;
    }

    // Create a character classifier that builds its own bpnet classifier.
    // `strfeatures` must contain exactly one flag per possible feature.
    ICharacterClassifier *make_BpnetCharacterClassifier(const char* strfeatures,
                                                        int dim_x,
                                                        int dim_y) {
        ASSERT(strlen(strfeatures)==nb_max_features);
        ICharacterClassifier *made =
            new LineCharacterClassifier(strfeatures, dim_x, dim_y);
        return made;
    }
}
