#!/usr/bin/env python
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
# 
# Written (W) 2007 Gunnar Raetsch
# Written (W) 2006-2008 Soeren Sonnenburg
# Copyright (C) 2006-2008 Fraunhofer Institute FIRST and Max-Planck-Society
# 

# All imports live in a single try block so that any missing module ends up in
# the ImportError handler below, which prints one message and exits with status 1.
try:
	import os
	import os.path
	import sys
	import pickle
	import bz2
	import numpy
	import optparse

	import genomic
	import model
	import seqdict
	import shogun.Structure

	# A DynProg object is created here only to query the installed shogun revision.
	d=shogun.Structure.DynProg()
	# Revisions below 2997 are rejected; the message reports this as needing
	# shogun 0.6.2 or later.
	if (d.version.get_version_revision() < 2997):
		print
		print "ERROR: SHOGUN VERSION 0.6.2 or later required"
		print
		sys.exit(1)
	from content_sensors import content_sensors
	from signal_detectors import signal_detectors
	from plif import plif
except ImportError:
	# Reached for a failure of any import above, not only the shogun one.
	print
	print "ERROR IMPORTING MODULES, MAKE SURE YOU HAVE SHOGUN INSTALLED"
	print
	sys.exit(1)


# Version string written by print_version() and passed on to
# genomic.write_gff() by msplicer.write_gff().
msplicer_version='v0.3'

class msplicer:
	def __init__(self):
		self.model = None
		self.plif = None
		self.signal = None
		self.content = None
		self.model_name = None

	def load_model(self, filename):
		"""Load the model stored in `filename` and build the helper objects.

		If `filename + '.pickle'` exists it is unpickled and used as the model.
		Otherwise `filename` is parsed with model.parse_file() (opened through
		bz2 when the name ends in '.bz2') and the result is pickled next to it
		for later runs.

		Afterwards self.plif, self.signal and self.content are created from
		the loaded model.
		"""
		self.model_name = filename
		sys.stderr.write('loading model file\n')
		picklefile = filename + '.pickle'

		if os.path.isfile(picklefile):
			# NOTE(security): unpickling can execute arbitrary code; the cache
			# file must come from a trusted location.
			cache = open(picklefile, 'rb')
			try:
				self.model = pickle.load(cache)
			finally:
				cache.close()
		else:
			if filename.endswith('.bz2'):
				source = bz2.BZ2File(filename)
			else:
				source = open(filename)
			try:
				self.model = model.parse_file(source)
			finally:
				# close the model file even when parsing fails
				source.close()

			# Writing the cache is best-effort: the parsed model is already
			# usable, so e.g. a read-only data directory must not abort loading.
			try:
				cache = open(picklefile, 'wb')
				try:
					pickle.dump(self.model, cache)
				finally:
					cache.close()
			except IOError:
				sys.stderr.write('warning: could not write model cache ' + picklefile + '\n')
				# A truncated cache would be picked up (and fail) on the next
				# run, so remove whatever was written.
				if os.path.isfile(picklefile):
					try:
						os.remove(picklefile)
					except OSError:
						# nothing more can be done; loading itself succeeded
						pass

		self.plif = plif(self.model)
		self.signal = signal_detectors(self.model)
		self.content = content_sensors(self.model)

	def compute_seqmatrix(self, seq):
		"""Build the signal score matrix for one sequence.

		`seq` must provide the attributes read below: .seq, .start, .end and
		.preds['donor'] / .preds['acceptor'], each with .positions and .scores.

		Returns the tuple (seqmatrix, unique_positions, plifstatemat):
		  seqmatrix        -- array with one row per entry of
		                      self.model.statedescr and one column per unique
		                      candidate position; holds the assigned scores and
		                      -inf wherever no score was assigned (or a state
		                      was masked out in the orf case)
		  unique_positions -- the distinct candidate positions as returned by
		                      numpy.unique (int32)
		  plifstatemat     -- int32 column with one row per state: 0 for
		                      acceptor states, 1 for donor states, -1 otherwise
		"""
		# start-state: 0
		# exon-start-state: 1
		# donor-state: 2
		# acceptor-state: 3
		# exon-end-state: 4
		# stop-state: 5

		# indices of the states of each kind, looked up by the codes listed above
		start_idx = numpy.where(self.model.statedescr == 0)[0]
		exon_start_idx = numpy.where(self.model.statedescr == 1)[0]
		don_idx = numpy.where(self.model.statedescr == 2)[0]
		acc_idx = numpy.where(self.model.statedescr == 3)[0]
		exon_stop_idx = numpy.where(self.model.statedescr == 4)[0]
		stop_idx = numpy.where(self.model.statedescr == 5)[0]

		# `positions` collects (position, score, state index/indices) tuples

		# start positions
		positions=[(0,0,start_idx)]
		positions.append((seq.start,0,exon_start_idx))

		# end positions
		positions.append((seq.end, 0, exon_stop_idx[0]))
		if len(exon_stop_idx)>1:
			# a second exon-end state receives the acceptor score predicted at
			# seq.end, but only if exactly one acceptor prediction sits there
			idx = numpy.where(numpy.array(seq.preds['acceptor'].positions,numpy.int32)==seq.end)[0]
			if len(idx)==1:
				positions.append((seq.end, seq.preds['acceptor'].scores[idx], exon_stop_idx[1]))
		# the stop state is placed on the last position of the sequence
		positions.append((len(seq.seq)-1,0,stop_idx))

		# donor positions: every donor state gets every donor prediction
		for i in don_idx:
			positions.extend(zip(seq.preds['donor'].positions,
								 seq.preds['donor'].scores,
								 len(seq.preds['donor'].positions)*[i]))

		# acceptor positions: every acceptor state gets every acceptor prediction
		for i in acc_idx:
			positions.extend(zip(seq.preds['acceptor'].positions,
								 list(seq.preds['acceptor'].scores),
								 len(seq.preds['acceptor'].positions)*[i]))

		# order by position (Python 2 cmp-style comparator); list.sort is stable,
		# so tuples sharing a position keep their insertion order
		positions.sort(cmp=lambda x,y : int(x[0]-y[0]))
		unique_positions= numpy.unique(numpy.array([ x[0] for x in positions ], numpy.int32))

		# start from -inf everywhere, then write each collected score into the
		# column of its position; a later tuple overwrites an earlier one that
		# targets the same cell
		seqmatrix= -numpy.infty * numpy.ones((len(self.model.statedescr),len(unique_positions)))
		for i in xrange(len(positions)):
			p = numpy.where(positions[i][0]==unique_positions)[0] ;
			assert(len(p)==1)
			p = p[0] ;
			seqmatrix[positions[i][2],p]=positions[i][1]

		# The two blocks below index don_idx / acc_idx up to [5], i.e. they rely
		# on at least six donor / acceptor states whenever there is more than one.
		# NOTE(review): the string tests presumably keep stop codons from being
		# assembled across a splice junction -- confirm against the model layout.
		if len(don_idx)>1: # orf case
			for i in xrange(len(unique_positions)):
				# only columns where the first donor state carries a score
				if seqmatrix[don_idx[0], i] > -1e20:
					# s1: the two characters ending at the position,
					# s2: the three characters ending at the position
					s1 = seq.seq[unique_positions[i]-1:unique_positions[i]+1]
					s2 = seq.seq[unique_positions[i]-2:unique_positions[i]+1]
					if s1 in ['TG']: seqmatrix[don_idx[1], i]=-numpy.infty 
					if s1 not in ['TG']: seqmatrix[don_idx[2], i]=-numpy.infty
					if s2 in ['TAG', 'TGG']: seqmatrix[don_idx[3], i]=-numpy.infty
					if s2 not in ['TAG']: seqmatrix[don_idx[4], i]=-numpy.infty
					if s2 not in ['TGG']: seqmatrix[don_idx[5], i]=-numpy.infty

		if len(acc_idx)>1: # orf case
			for i in xrange(len(unique_positions)):
				# only columns where the first acceptor state carries a score
				if seqmatrix[acc_idx[0], i] > -1e20:
					# s1: characters at position-1 and position,
					# s2: characters at position-1 .. position+1
					s1 = seq.seq[unique_positions[i]-1:unique_positions[i]+1]
					s2 = seq.seq[unique_positions[i]-1:unique_positions[i]+2]
					if s2 in ['GAA', 'GAG', 'GGA']: seqmatrix[acc_idx[2], i]=-numpy.infty
					if s1 in ['GA', 'GG']: seqmatrix[acc_idx[4], i]=-numpy.infty
					if s1 in ['GA']: seqmatrix[acc_idx[5], i]=-numpy.infty
			
		# one row per state; -1 marks states that are neither donor nor acceptor
		plifstatemat = -numpy.ones((len(self.model.statedescr),1), numpy.int32);
		plifstatemat[acc_idx,0] = 0 ; # acceptors use first plif
		plifstatemat[don_idx,0] = 1 ; # donors use second plif
		
		return (seqmatrix, unique_positions, plifstatemat)


	def initialize_dynprog(self, seq):
		"""Create a shogun DynProg object and load it with everything needed
		to decode `seq`.

		Returns the tuple (dyn, positions): the configured DynProg instance and
		the candidate position array obtained from compute_seqmatrix().
		"""
		decoder = shogun.Structure.DynProg()

		self.content.initialize_content(decoder)

		# state count plus the p, q and a_trans parameters of the model
		decoder.set_num_states(len(self.model.p))
		decoder.set_p_vector(self.model.p)
		decoder.set_q_vector(self.model.q)
		decoder.set_a_trans_matrix(self.model.a_trans)

		# signal scores, candidate positions and the plif/state assignment
		score_matrix, candidate_pos, plif_state_matrix = self.compute_seqmatrix(seq)

		decoder.best_path_set_seq(score_matrix)
		decoder.best_path_set_pos(candidate_pos)
		decoder.best_path_set_orf_info(self.model.orf_info)

		decoder.best_path_set_plif_list(self.plif.get_plif_array())

		decoder.best_path_set_plif_id_matrix(self.model.plifidmat.T)
		decoder.best_path_set_plif_state_signal_matrix(plif_state_matrix)

		# the sequence is handed over as an array with one element per item of seq.seq
		decoder.best_path_set_single_genestr(numpy.array(list(seq.seq)))
		decoder.best_path_set_dict_weights(self.content.get_dict_weights())

		return decoder, candidate_pos

	#def precompute_content_svm_values(self, dyn, seq, positions):
	#	wordstr=dyn.create_word_string(seq, 1, len(seq));
	#	dyn.init_content_svm_value_array(Npos)
	#	weights = self.content.get_dict_weights()
	#	#n = size(weights, 1)
	#	#m = size(weights, 2)
	#	dyn.precompute_content_values(wordstr, positions, len(positions), len(seq), self.content.get_dict_weights(), n*m);
	#	dyn.set_genestr_len(len(seq));
	#	return (dyn)

	def write_gff(self, outfile, pred, name, score, skipheader):
		"""Emit one GFF 'exon' record per row of `pred` via genomic.write_gff().

		`pred` is read as pred[i,0] / pred[i,1]: column 0 is incremented by one
		for the 'start' field, column 1 is used unchanged as 'end'. Every
		record carries the same name and score, strand '+' and frame 0.
		"""
		records = [
			{
				'seqname': name,
				'source': 'msplicer',
				'feature': 'exon',
				'start': pred[row,0]+1,
				'end': pred[row,1],
				'score': score,
				'strand': '+',
				'frame': 0,
			}
			for row in xrange(pred.shape[0])
		]

		# (program name, version followed by the model file name)
		producer = ('msplicer', msplicer_version + ' ' + self.model_name)
		genomic.write_gff(outfile, producer, ('DNA', name), records, skipheader)

	def predict_file(self, fname, startstop, outfile=None):
		"""Predict exons for every sequence of a FASTA file and write them as GFF.

		fname     -- path of the FASTA file to read
		startstop -- (start, end) pair; passed to seqdict.seqdict() together
		             with the parsed sequences
		outfile   -- output target handed to write_gff(). When omitted, the
		             module-level name 'outfile' is used if it exists (the
		             script entry point defines it), otherwise sys.stdout.

		The GFF header is requested for the first sequence only; later
		sequences are written with skipheader=True.
		"""
		(start, end) = startstop
		if outfile is None:
			# keep working for callers that relied on the global set by the script
			outfile = globals().get('outfile', sys.stdout)

		fasta_file = open(fname)
		try:
			fasta_dict = genomic.read_fasta(fasta_file)
		finally:
			fasta_file.close()
		sys.stderr.write('found fasta file with ' + repr(len(fasta_dict)) + ' sequence(s)\n')
		seqs = seqdict.seqdict(fasta_dict, (start, end))

		# get donor/acceptor signal predictions for all sequences
		self.signal.predict_acceptor_sites_from_seqdict(seqs)
		self.signal.predict_donor_sites_from_seqdict(seqs)

		skipheader = False
		for seq in seqs:
			# initialize dynamic programming, with content sensors,
			# signal detectors, Plifs and HMM like model
			(dyn, positions) = self.initialize_dynprog(seq)

			# compute max likely path
			dyn.best_path_call(1, self.model.use_orf)
			scores = dyn.best_path_get_scores()
			states = dyn.best_path_get_states()
			pos = dyn.best_path_get_positions()

			# The path is the prefix of row 0 in front of the first -1 entry.
			# A plain integer is needed as slice bound: an index array only
			# works there when it happens to hold exactly one element.
			terminators = numpy.where(pos[0] == -1)[0]
			if len(terminators) > 0:
				path_len = int(terminators[0])
			else:
				# no -1 entry present: treat the whole row as the path
				path_len = len(pos[0])

			# drop the first and the last element of the path
			pred_states = states[0][0:path_len][1:-1]
			pred = positions[pos[0][0:path_len][1:-1]]

			if len(pred_states) > 0 and pred_states[-1] == 15:
				# joint state for acceptor and stop codon: repeat the last
				# position so that the positions can be paired up below
				# NOTE(review): state id 15 is hard-coded -- confirm it holds
				# for every supported model
				padded = numpy.zeros(len(pred) + 1, numpy.int32)
				padded[0:len(pred)] = pred
				padded[-1] = pred[-1]
				pred = padded

			# consecutive positions form one (start, end) row each
			pred = pred.reshape((len(pred) // 2, 2))
			self.write_gff(outfile, pred, seq.name, scores, skipheader)
			skipheader = True

def print_version():
	"""Write the program name followed by msplicer_version to stderr."""
	banner = 'mSplicer %s\n' % msplicer_version
	sys.stderr.write(banner)

def parse_options():
	"""Parse and validate the command line (sys.argv).

	Returns the tuple ((start, stop), fafname, modelfname, outfile) where
	start/stop are the --start/--stop values decremented by one, fafname is
	the FASTA file given as the single positional argument, modelfname the
	model file derived from --model, and outfile either sys.stdout or the
	file named by --outfile opened for writing.

	Invalid input is reported through parser.error(), which exits the
	program. -v/--version prints the version and exits with status 0.
	"""
	parser = optparse.OptionParser(usage="usage: %prog [options] seq.fa")

	parser.add_option("-o", "--outfile", type="str", default='stdout',
			help="File to write the results to")
	# store_true makes -v a real flag; without an action optparse would
	# demand an argument for it
	parser.add_option("-v", "--version", action="store_true", default=False,
			help="Show version information and exit")
	parser.add_option("--start", type="int", default=499,
			help="coding start (zero based, relative to sequence start)")
	parser.add_option("--stop", type="int", default=-499,
			help="coding stop (zero based, if positive relative to "
				"sequence start, if negative relative to sequence end)")
	parser.add_option("--model", type="str", default='WS160',
			help="mSplicer Model to use in predicting")

	(options, args) = parser.parse_args()
	if options.version:
		print_version()
		sys.exit(0)

	if len(args) != 1:
		parser.error("incorrect number of arguments")

	fafname = args[0]
	if not os.path.isfile(fafname):
		parser.error("fasta file does not exist")

	# A trailing 'gc' and a leading 'orf' in the model name select the gc / orf
	# variant of the model file. (The local is not called 'model' so that the
	# imported module of that name is not shadowed.)
	model_name = options.model
	gc = 0
	if model_name.endswith('gc'):
		gc = 1
		model_name = model_name[:-2]
	orf = 0
	if model_name.startswith('orf'):
		orf = 1
		model_name = model_name[3:]

	modelfname = 'data/msplicer_elegans%s_gc=%d_orf=%d.dat.bz2' % (model_name, gc, orf)
	# Progress goes to stderr: stdout is the default destination of the
	# results and must not be mixed with diagnostics.
	sys.stderr.write("loading model file " + modelfname)

	if not os.path.isfile(modelfname):
		sys.stderr.write(" ...not found!\n\n")
		parser.error("""model should be one of:

WS120, WS120gc, orfWS120, WS150,
WS160, WS160gc, orfWS160gc
""")
	sys.stderr.write("\n")

	# Validate the coordinates before touching the output file, so that a
	# rejected command line never creates or truncates it.
	if options.start < 80:
		parser.error("--start value must be >=80")

	# stop == 0 is covered here as well; it would otherwise slip through both
	# the positive and the negative range check
	if options.stop >= 0 and options.start >= options.stop - 80:
		parser.error("--stop value must be > start + 80")

	if options.stop < 0 and options.stop > -80:
		parser.error("--stop value must be <= - 80")

	if options.outfile == 'stdout':
		outfile = sys.stdout
	else:
		try:
			outfile = open(options.outfile, 'w')
		except IOError:
			parser.error("could not open %s for writing" % options.outfile)

	# shift the start and stop a bit
	# NOTE(review): presumably an index-convention adjustment -- confirm
	# against seqdict
	start = options.start - 1
	stop = options.stop - 1

	return ((start, stop), fafname, modelfname, outfile)


if __name__ == '__main__':
	# a DynProg instance is created before option parsing; it is not
	# referenced again in this block
	dyn = shogun.Structure.DynProg()

	# 'outfile' has to remain a module-level name here: msplicer.predict_file()
	# looks it up globally when writing the predictions
	startstop, fafname, modelfname, outfile = parse_options()

	p = msplicer()
	p.load_model(modelfname)
	p.predict_file(fafname, startstop)
