import unrolling, ltl, expressions, config, preproc, states, brind, ltl, copy, yicesfull
from ltbmc import LOOP_VAR, LOOPEL_VAR, OK_VAR, SMALLER_EXISTS, APPROX_SUFFIX
from names import DELTAVAR_NAME, DELTA_FRAC, DELTA_INT, OPEN, FAIR_PREFIX

# Name fragments for helper variables. None of them is referenced in the
# code shown here -- presumably consumed by the MITL encoder or other
# modules; TODO confirm.
HELPER_SUBEXPR_SUFFIX_1, HELPER_SUBEXPR_SUFFIX_2 = '-A', '-B'
HELPER_CLOCK_PREFIX, LEFT_CLOSED_PREFIX = 'helper-clk-', 'left-closed-'

# TODO: check the encoding of each operator.
# TODO: check the loop-end constraint.
# TODO: check non-Zenoness.
# TODO: variables like +obligation on loop

class MITLBMC(unrolling.BMC):
	"""Bounded model checker for MITL properties on top of unrolling.BMC.

	The negated property is encoded by ltl.encode_neg_mitl into additional
	initial, invariant and transition constraints, fairness conditions and
	state variables.  At each bound a lasso-shaped counterexample is sought:
	some earlier step (marked by LOOP_VAR) must match the last step on the
	state variables and on the clock regions, and the loop must satisfy the
	non-Zeno and fairness bookkeeping asserted by assert_loop_var_permanent.
	"""

	# Class-level configuration flags; they are not read in this class and are
	# presumably consumed by unrolling.BMC or the driver -- TODO confirm.
	FULL_XINVARIANT = True
	COMPLETE = False
	INT_FRAC_SPLIT = True
	SPECTYPE = 'MITLSPEC'
	COMBINED_TRANS = False
	INTERVAL_TRANS = True

	def __init__(self, variables, clockmax, special_vars, constants, initials,
	             _, invariants, __, transitions, prop, ___,
	             definitions, transition_definitions, statistics):
		"""Encode the negated MITL property and initialise the base BMC.

		The arguments are modified in place: `variables`, `clockmax` and
		`special_vars` are updated, and the encoding of the negated property
		is added to `initials`, `invariants` and `transitions` with `+=`
		(presumably lists, so they are extended in place).  The unnamed
		parameters `_`, `__` and `___` are ignored; None is passed to the base
		class in their positions.

		If config.yicesfile is set, the encoder's comments are written to a
		sibling file named by _comments_path.
		"""
		yicesfull.enable_type_checker(1)
		# Current-state (unprimed) variables only.
		self.statevars = [v for v in variables
		                  if not v.endswith(preproc.PRIME_SUFFIX)]

		# Loop / non-Zeno bookkeeping variables, added before encoding so
		# that the encoder receives them in `variables`.
		variables[SMALLER_EXISTS] = expressions.BOOLEAN
		variables[LOOP_VAR] = expressions.BOOLEAN
		variables[LOOPEL_VAR] = expressions.BOOLEAN
		for clk in clockmax:
			variables[OK_VAR + clk] = expressions.BOOLEAN

		variables.update(special_vars)
		variables[DELTA_FRAC] = expressions.REAL
		variables[DELTA_INT] = expressions.INTEGER
		variables[OPEN + preproc.PRIME_SUFFIX] = expressions.BOOLEAN

		ninitial, ninvar, ntrans, nfairness, nsvars, nclockmax, comments = \
				ltl.encode_neg_mitl(prop, variables, constants, definitions)

		if config.yicesfile is not None:
			with open(self._comments_path(config.yicesfile), 'w') as f:
				f.write(comments)

		initials += ninitial
		invariants += ninvar
		transitions += ntrans
		self.statevars += nsvars
		clockmax.update(nclockmax)
		self.fairness = nfairness
		# One flag (plus its primed copy) per fairness condition; their
		# meaning is set up in assert_loop_var_permanent.
		for i in xrange(len(self.fairness)):
			vn = FAIR_PREFIX + str(i)
			special_vars[vn] = expressions.BOOLEAN
			special_vars[vn + preproc.PRIME_SUFFIX] = expressions.BOOLEAN

		# clockmax may now contain clocks introduced by the encoder.
		for clk in clockmax:
			variables[OK_VAR + clk] = expressions.BOOLEAN

		# Strip primed names, the integer/fractional clock components,
		# special variables and the delta components from `variables`
		# before handing it to the base class.
		for vn in list(variables):
			if vn.endswith(preproc.PRIME_SUFFIX):
				del variables[vn]

		for cn in clockmax:
			variables.pop(cn + brind.INT_SUFFIX, None)
			variables.pop(cn + brind.FRAC_SUFFIX, None)

		for vn in special_vars:
			variables.pop(vn, None)
		# pop() rather than del: these may already have been removed above
		# if special_vars happens to contain them.
		variables.pop(DELTA_FRAC, None)
		variables.pop(DELTA_INT, None)

		special_vars[DELTA_INT + preproc.PRIME_SUFFIX] = expressions.INTEGER
		special_vars[DELTA_FRAC + preproc.PRIME_SUFFIX] = expressions.REAL
		special_vars[DELTAVAR_NAME] = expressions.REAL
		special_vars[DELTAVAR_NAME + preproc.PRIME_SUFFIX] = expressions.REAL

		unrolling.BMC.__init__(self, variables, clockmax, special_vars, constants,
				initials, None, invariants, None, transitions, prop, None,
				definitions, transition_definitions, statistics)
		# Base-class attributes that this class never references.
		del self.property
		del self.clockvarcount
		self.bound = config.mitl_bound

	@staticmethod
	def _comments_path(path):
		"""Return `path` with '-comments' inserted before its extension.

		'out.ys' -> 'out-comments.ys', 'a.b.ys' -> 'a.b-comments.ys',
		'out' -> 'out-comments'.  Only the last path component is examined
		for an extension, so dotted directories such as './out' are handled
		correctly ('./out-comments').
		"""
		import os
		root, ext = os.path.splitext(path)
		return root + '-comments' + ext

	def raw_ast_asn(self, *arg):
		"""Encode the AST tuple (op, operand, ...) and assert it on the solver.

		Unlike self.assertion, no step index is supplied, so operands are
		expected to be already step-indexed names (VAR_AT_STEP_FMT).
		"""
		self.yi.assertion(self.yi.encode(expressions.AstExpression(arg)))

	def assert_loop_var_permanent(self, index):
		"""Assert the loop / non-Zeno / fairness bookkeeping for step `index`.

		For index > 0 the following is asserted at step index-1 (primed names
		presumably denoting step `index`, as in `step`):
		  * LOOPEL_VAR <-> LOOPEL_VAR' | OPEN
		  * OK_VAR+clk <-> OK_VAR+clk' | (clk is exactly zero)
		                 | (clk exceeds its maximum constant)
		  * FAIR_PREFIX+i <-> FAIR_PREFIX+i' | (i-th fairness condition)
		Combined with assert_loop_var_constraints, which forces every such
		flag false at the final step, each flag records whether its
		condition holds at some step from here to the end of the unrolling.

		At step `index`: LOOP_VAR implies LOOPEL_VAR, every OK flag, every
		fairness flag, and DELTA_INT = DELTA_FRAC = 0.
		"""
		zero = expressions.Number(0)
		if index > 0:
			asns = [expressions.AstExpression('<->',
					LOOPEL_VAR,
					('|', LOOPEL_VAR + preproc.PRIME_SUFFIX, OPEN))]
			for clk in self.clockmaxdict:
				const = expressions.Number(self.clockmaxdict[clk])
				cint = clk + brind.INT_SUFFIX
				cfrac = clk + brind.FRAC_SUFFIX
				is_zero = ('&', ('=', cint, zero), ('=', cfrac, zero))
				# int > c, or int = c with a non-zero fractional part.
				above_max = ('|',
						('>', cint, const),
						('&', ('=', cint, const), ('>', cfrac, zero)))
				asns.append(expressions.AstExpression('<->',
						OK_VAR + clk,
						('|',
							OK_VAR + clk + preproc.PRIME_SUFFIX,
							('|', is_zero, above_max))))

			for i, f in enumerate(self.fairness):
				vn = FAIR_PREFIX + str(i)
				asns.append(expressions.AstExpression('<->',
						vn,
						('|', vn + preproc.PRIME_SUFFIX, f)))

			self.assertion(asns, index - 1)

		asns = [expressions.AstExpression('->', LOOP_VAR, LOOPEL_VAR)]
		for clk in self.clockmaxdict:
			asns.append(expressions.AstExpression('->', LOOP_VAR, OK_VAR + clk))
		for i in xrange(len(self.fairness)):
			asns.append(expressions.AstExpression('->', LOOP_VAR, FAIR_PREFIX + str(i)))

		# Loop only to zero delta state TODO: ??
		asns.append(expressions.AstExpression('->', LOOP_VAR, ('=', DELTA_INT, zero)))
		asns.append(expressions.AstExpression('->', LOOP_VAR, ('=', DELTA_FRAC, zero)))
		self.assertion(asns, index)

	def assert_loop_var_constraints(self, index):
		"""Assert the constraints that close the lasso at final step `index`.

		For every earlier step `ind`, if LOOP_VAR holds at `ind`:
		  * each non-special state variable has the same value at `ind` and
		    at `index`;
		  * for each clock, the integer parts are equal or both exceed the
		    clock's maximum constant;
		  * if the clock's integer part at `index` is within its maximum, the
		    ordering of fractional parts against every other such clock and
		    the zero-ness of its fractional part agree between `ind` and
		    `index`.
		At step `index` itself: LOOPEL_VAR, every OK flag and every fairness
		flag are false, SMALLER_EXISTS holds, and the delta is zero.

		These are asserted directly on the solver; `check` wraps the call in
		push/pop so that they are scoped to a single bound.
		"""
		# TODO: make sure that loop is closed from singular to singular or from open to open
		zero = expressions.Number(0)
		# Hoisted: the set of compared state variables does not depend on ind.
		compared_vars = [var for var in self.statevars
		                 if var not in self.special_vars]
		for ind in xrange(index):
			loopvar = unrolling.VAR_AT_STEP_FMT % (LOOP_VAR, ind)

			# State variables
			for var in compared_vars:
				vend = unrolling.VAR_AT_STEP_FMT % (var, index)
				vinner = unrolling.VAR_AT_STEP_FMT % (var, ind)
				self.yi.assertion(self.yi.encode(expressions.AstExpression('|',
						('!', loopvar),
						('=', vend, vinner))))

			# Region based
			# ----------------------- Basic encoding -----------------------
			for clka in self.clockmaxdict:
				aint = clka + brind.INT_SUFFIX
				ainte = unrolling.VAR_AT_STEP_FMT % (aint, index)
				ainti = unrolling.VAR_AT_STEP_FMT % (aint, ind)
				afrac = clka + brind.FRAC_SUFFIX
				afrace = unrolling.VAR_AT_STEP_FMT % (afrac, index)
				afraci = unrolling.VAR_AT_STEP_FMT % (afrac, ind)
				amax = expressions.Number(self.clockmaxdict[clka])

				self.raw_ast_asn('->',
						loopvar,
						('|',
							('&', ('>', ainti, amax), ('>', ainte, amax)),
							('=', ainti, ainte)))

				innera = expressions.AstExpression('&', loopvar, ('<=', ainte, amax))

				for clkb in self.clockmaxdict:
					if clkb == clka:
						continue
					bint = clkb + brind.INT_SUFFIX
					binte = unrolling.VAR_AT_STEP_FMT % (bint, index)
					bfrac = clkb + brind.FRAC_SUFFIX
					bfrace = unrolling.VAR_AT_STEP_FMT % (bfrac, index)
					bfraci = unrolling.VAR_AT_STEP_FMT % (bfrac, ind)
					bmax = expressions.Number(self.clockmaxdict[clkb])

					self.raw_ast_asn('->',
							('&', innera, ('<=', binte, bmax)),
							('<->', ('<=', afraci, bfraci), ('<=', afrace, bfrace)))

				self.raw_ast_asn('->', innera, ('<->', ('=', afraci, zero), ('=', afrace, zero)))

		# ----------------------- Non-zeno basic -----------------------
		asns = [expressions.AstExpression('!', LOOPEL_VAR)]
		for clk in self.clockmaxdict:
			asns.append(expressions.AstExpression('!', OK_VAR + clk))
		for i in xrange(len(self.fairness)):
			asns.append(expressions.AstExpression('!', FAIR_PREFIX + str(i)))

		# We need loop
		asns.append(expressions.AstExpression(SMALLER_EXISTS))
		# TODO: only loop on zero delta states
		asns.append(expressions.AstExpression('=', DELTA_INT, zero))
		asns.append(expressions.AstExpression('=', DELTA_FRAC, zero))
		self.assertion(asns, index)

	def add_variables(self, index):
		"""Declare the variables for step `index` (delegates to the base class)."""
		unrolling.BMC.add_variables(self, index)

	def step(self, index, initial_act = None):
		"""Unroll step `index` and add the loop-selection constraints.

		SMALLER_EXISTS is false at step 0 and otherwise becomes true once
		LOOP_VAR has held at some earlier step; LOOP_VAR and SMALLER_EXISTS
		may not both hold at the same step, so at most one loop start is
		chosen.  Finally the per-step bookkeeping of
		assert_loop_var_permanent is asserted.
		"""
		unrolling.BMC.step(self, index, initial_act)

		# SmallerExists
		if index == 0:
			self.assertion([expressions.AstExpression('=',
					SMALLER_EXISTS,
					expressions.BooleanConstant(False))], 0)
		else:
			self.assertion([expressions.AstExpression('=',
					SMALLER_EXISTS + preproc.PRIME_SUFFIX,
					('|', SMALLER_EXISTS, LOOP_VAR))], index - 1)

		# AtmostOne
		self.assertion([expressions.AstExpression('!',
				('&', SMALLER_EXISTS, LOOP_VAR))], index)

		self.assert_loop_var_permanent(index)

	def check(self, index):
		"""Check for a lasso counterexample of length `index`.

		Returns (False, counterexample) when the solver finds a model and
		(True, None) otherwise.  The loop-closing constraints are asserted in
		a pushed solver context that is always popped, including when
		asserting or solving raises.
		"""
		self.yi.push()
		try:
			self.assert_loop_var_constraints(index)
			if self.yi.check():
				return False, self.get_counter_example(self.yi.get_model())
			return True, None
		finally:
			self.yi.pop()

	def get_counter_example(self, mdl):
		"""Build a states.CounterExample from the solver model `mdl`.

		Holds one (step index, states.State) pair per entry of self.var_low.
		"""
		ce = states.CounterExample()
		for i, low in enumerate(self.var_low):
			ce.append((i, states.State(mdl, low, self)))
		return ce

if __name__ == "__main__":
	# Debug entry point: echo each command-line argument followed by the
	# attribute listing of nusmv_yacc.expression_parser and a blank line.
	# Single-argument print(...) calls give identical output under Python 2.
	import nusmv_yacc
	import sys
	cli_args = sys.argv[1:]
	if not cli_args:
		print('No arguments')
	for cli_arg in cli_args:
		print(cli_arg)
		print(dir(nusmv_yacc.expression_parser))
		print('')

