# linear time BMC

import unrolling, ltl, expressions, config, preproc, states, brind

# Special variables:
# Names (or name fragments) of the auxiliary variables used by the encoding
# in this file.

# Variable read by the basic non-zeno constraint in
# LTBMC.assert_loop_var_permanent. It is not declared in this file --
# presumably provided by the unrolling; TODO confirm.
ELAPSE_VAR = 'elapse-'
# Boolean variable: asserted false at step 0; at every later step it equals
# (SMALLER_EXISTS | LOOP_VAR) of the previous step (see LTBMC.step).
SMALLER_EXISTS = '-smaller-exists'
# Suffix of the second boolean variable that LTBMC.__init__ creates for every
# subformula whose top-level operator is in ltl.MITLOPS.
APPROX_SUFFIX = '-approx'
# Boolean variable declared when neither config.ltl_bmc_lasso nor
# config.ltl_bmc_one_nonzeno is set; implied by LOOP_VAR.
LOOPEL_VAR = '-loop-elapse'
# Prefix of the per-clock boolean variables (OK_VAR + clock name) declared in
# the same configuration as LOOPEL_VAR; each is implied by LOOP_VAR.
OK_VAR = '-ok-'
# Base name of the variable(s) declared when config.ltl_bmc_one_nonzeno is
# set (without config.ltl_bmc_lasso); LOOP_VAR requires the value to be >= 1.
LOOP_DELTA_VAR = 'loop-delta'
from names import LOOP_VAR

class LTBMC(unrolling.BMC):
	"""Bounded model checking of an LTL property on top of unrolling.BMC.

	The constructor negates the property, converts it to positive normal
	form and splits it into subformulas (ltl.positive_normal_form and
	ltl.split_to_subformulas). Every subformula gets a boolean variable;
	subformulas whose operator is in ltl.MITLOPS get a second one carrying
	APPROX_SUFFIX.
	"""
	# Class-level switches. Of these, only INT_FRAC_SPLIT is read by the
	# methods visible in this class (assert_loop_var_permanent); the others
	# are presumably consumed by unrolling.BMC or by subclasses -- TODO confirm.
	FULL_XINVARIANT = True
	COMPLETE = False
	INT_FRAC_SPLIT = False
	SPECTYPE = 'LTLSPEC'
	COMBINED_TRANS = False
	INTERVAL_TRANS = False

	def __init__(self, variables, clockmax, special_vars, constants, initials,
			xinitials, invariants, xinvariants, transitions, prop, xproperty,
			definitions, transition_definitions, statistics):
		"""Declare the auxiliary variables and prepare the negated property.

		All arguments are passed on unchanged to unrolling.BMC.__init__.
		Before that call, 'variables' is extended in place with the loop
		and non-zenoness variables selected by the config flags.
		After it, the negated property is brought into positive normal form
		and split into subformulas, and the per-subformula boolean variables
		are registered in self.variables.
		"""
		lasso = config.ltl_bmc_lasso
		boolean = expressions.BOOLEAN

		# State variables: every variable that is not a primed copy, plus the
		# clocks when lasso-shaped paths are requested.
		self.statevars = [v for v in variables if not v.endswith(preproc.PRIME_SUFFIX)]
		if lasso:
			self.statevars.extend(clockmax)

		variables[SMALLER_EXISTS] = boolean
		variables[LOOP_VAR] = boolean
		if not lasso:
			mixed = config.ltl_bmc_mixed_type
			if mixed:
				for clk in clockmax:
					variables[clk + brind.INT_SUFFIX] = expressions.INTEGER
					variables[clk + brind.FRAC_SUFFIX] = expressions.REAL
			if not config.ltl_bmc_one_nonzeno:
				variables[LOOPEL_VAR] = boolean
				for clk in clockmax:
					variables[OK_VAR + clk] = boolean
			elif mixed:
				variables[LOOP_DELTA_VAR] = expressions.REAL
			else:
				variables[LOOP_DELTA_VAR + brind.INT_SUFFIX] = expressions.INTEGER
				variables[LOOP_DELTA_VAR + brind.FRAC_SUFFIX] = expressions.REAL

		unrolling.BMC.__init__(
			self, variables, clockmax, special_vars, constants, initials,
			xinitials, invariants, xinvariants, transitions, prop, xproperty,
			definitions, transition_definitions, statistics)

		self.bound = config.ltl_bmc_bound

		# self.property, self.variables, self.constants and self.definitions
		# are not set in this method -- presumably by the base constructor.
		negated = expressions.AstExpression("!", self.property)
		pnf = ltl.positive_normal_form(negated, self.variables, self.constants, self.definitions)
		self.nproperty, self.propexprs = ltl.split_to_subformulas(pnf)
		# The top-level formula itself must not be a MITL operator application.
		assert not (isinstance(self.nproperty, expressions.AstExpression)
				and self.nproperty[0] in ltl.MITLOPS)
		del self.property

		# Entries are (variable name, subformula, is-approximation-variable).
		# A MITL subformula is immediately followed by its approximation entry.
		self.propexpr_list = []
		primed = preproc.PRIME_SUFFIX
		for sub_name, sub_expr in self.propexprs.iteritems():
			self.propexpr_list.append((sub_name, sub_expr, False))
			self.variables[sub_name] = boolean
			self.variables[sub_name + primed] = boolean
			if not isinstance(sub_expr, expressions.AstExpression):
				continue
			if sub_expr[0] not in ltl.MITLOPS:
				continue
			approx_name = sub_name + APPROX_SUFFIX
			self.propexpr_list.append((approx_name, sub_expr, True))
			self.variables[approx_name] = boolean
			self.variables[approx_name + primed] = boolean
		self.propexpr_low = []
	
	def add_special_var(self, name, typ):
		"""Register 'name' with type 'typ' as a special variable.

		The name is stored in self.special_vars and self.variables and
		appended to self.ordered_vars. Returns the last index of
		self.ordered_vars after the append, i.e. the position of 'name'.
		"""
		for table in (self.special_vars, self.variables):
			table[name] = typ
		order = self.ordered_vars
		order.append(name)
		return len(order) - 1
	
	def raw_ast_asn(self, *arg):
		"""Build an AST expression from the positional arguments and assert it.

		The arguments are handed to expressions.AstExpression as one tuple;
		the result is encoded with self.yi.encode and passed to
		self.yi.assertion.
		"""
		yi = self.yi
		formula = expressions.AstExpression(arg)
		yi.assertion(yi.encode(formula))
	
	def assert_loop_var_permanent(self, index):
		"""Assert the non-zenoness constraints tied to LOOP_VAR for one step.

		Does nothing when config.ltl_bmc_lasso is set. Otherwise, for
		index > 0 a constraint relating step index - 1 to its primed
		variables is asserted at index - 1, and an implication from
		LOOP_VAR is asserted at 'index'.
		"""
		if config.ltl_bmc_lasso:
			return

		make = expressions.AstExpression
		primed = preproc.PRIME_SUFFIX

		if not config.ltl_bmc_one_nonzeno:
			# Nonzeno basic
			if index > 0:
				zero = expressions.Number(0)
				# LOOPEL_VAR holds iff it holds in the primed state or
				# ELAPSE_VAR holds; OK_VAR + clk likewise, with the test
				# "integer part of clk is zero" in place of ELAPSE_VAR.
				backward = [make('<->', LOOPEL_VAR, ('|', LOOPEL_VAR + primed, ELAPSE_VAR))]
				backward += [
					make('<->', OK_VAR + clk, ('|', OK_VAR + clk + primed, ('=', clk + brind.INT_SUFFIX, zero)))
					for clk in self.clockmaxdict]
				self.assertion(backward, index - 1)
			at_loop = [make('->', LOOP_VAR, LOOPEL_VAR)]
			at_loop += [make('->', LOOP_VAR, OK_VAR + clk) for clk in self.clockmaxdict]
			self.assertion(at_loop, index)
			return

		# Nonzeno one
		one = expressions.Number(1)
		if not self.INT_FRAC_SPLIT:
			if index > 0:
				accumulate = make('=', LOOP_DELTA_VAR, ('+', LOOP_DELTA_VAR + primed, brind.DELTAVAR_NAME))
				self.assertion([accumulate], index - 1)
			self.assertion([make('->', LOOP_VAR, ('>=', LOOP_DELTA_VAR, one))], index)
			return

		zero = expressions.Number(0)
		delta_int = LOOP_DELTA_VAR + brind.INT_SUFFIX
		delta_frac = LOOP_DELTA_VAR + brind.FRAC_SUFFIX
		# Carry between fractional and integer part.
		# NOTE(review): the carry test reads the unprimed fraction, while
		# the two update equations below read the primed one -- confirm
		# this is intended.
		carry = make('?:', ('>=', ('+', delta_frac, brind.DELTA_FRAC), one), one, zero)
		if index > 0:
			int_update = make('=', delta_int, ('+', ('+', delta_int + primed, brind.DELTA_INT), carry))
			frac_update = make('=', delta_frac, ('-', ('+', delta_frac + primed, brind.DELTA_FRAC), carry))
			self.assertion([int_update, frac_update], index - 1)
		self.assertion([make('->', LOOP_VAR, ('>=', delta_int, one))], index)
	
	def assert_loop_var_constraints(self, index):
		"""Assert what LOOP_VAR at an earlier step means for the step 'index'.

		For every step ind < index, LOOP_VAR at ind forces all non-special
		state variables to have the same value at ind and at index. With
		config.ltl_bmc_lasso the same equality is asserted for the clocks;
		otherwise the clocks are related through their integer and
		fractional parts. Without lasso, the non-zenoness variables are
		additionally constrained at step 'index'.
		"""
		# Not used below; the lookup is retained so that the method behaves
		# exactly as before for an 'index' that self.var_low does not accept.
		end_low = self.var_low[index]
		zero = expressions.Number(0)
		fmt = unrolling.VAR_AT_STEP_FMT
		make = expressions.AstExpression
		lasso = config.ltl_bmc_lasso

		def same_at_loop(names, loop_at, ind):
			# loop_at -> (value at 'index' == value at 'ind'), skipping special variables
			for var in names:
				if var in self.special_vars:
					continue
				equal = ('=', fmt % (var, index), fmt % (var, ind))
				self.yi.assertion(self.yi.encode(make('|', ('!', loop_at), equal)))

		# Loop variables
		for ind in xrange(index):
			loop_at = fmt % (LOOP_VAR, ind)

			# State variables
			same_at_loop(self.statevars, loop_at, ind)

			if lasso:
				# Lasso shaped paths
				same_at_loop(self.clockmaxdict, loop_at, ind)
				continue

			# Region based
			# ----------------------- Basic encoding -----------------------
			for clka in self.clockmaxdict:
				a_int = clka + brind.INT_SUFFIX
				a_frac = clka + brind.FRAC_SUFFIX
				a_int_end = fmt % (a_int, index)
				a_int_in = fmt % (a_int, ind)
				a_frac_end = fmt % (a_frac, index)
				a_frac_in = fmt % (a_frac, ind)
				a_max = expressions.Number(self.clockmaxdict[clka])

				# Integer parts agree, or both exceed the clock's maximum.
				self.raw_ast_asn('->', loop_at, ('|', ('&', ('>', a_int_in, a_max), ('>', a_int_end, a_max)), ('=', a_int_in, a_int_end)))

				bounded_a = make('&', loop_at, ('<=', a_int_end, a_max))

				# Order of the fractional parts is the same at both steps for
				# every pair of clocks within their maxima at step 'index'.
				for clkb in self.clockmaxdict:
					if clkb == clka:
						continue
					b_int_end = fmt % (clkb + brind.INT_SUFFIX, index)
					b_frac = clkb + brind.FRAC_SUFFIX
					b_frac_end = fmt % (b_frac, index)
					b_frac_in = fmt % (b_frac, ind)
					b_max = expressions.Number(self.clockmaxdict[clkb])

					self.raw_ast_asn('->', ('&', bounded_a, ('<=', b_int_end, b_max)), ('<->', ('<=', a_frac_in, b_frac_in), ('<=', a_frac_end, b_frac_end)))

				# Fractional part is zero at one step iff it is zero at the other.
				self.raw_ast_asn('->', bounded_a, ('<->', ('=', a_frac_in, zero), ('=', a_frac_end, zero)))

			#TODO: Split

			#TODO: Limited

		if lasso:
			return

		final = []
		mixed = config.ltl_bmc_mixed_type
		if config.ltl_bmc_one_nonzeno:
			# ------------------------ Non-zeno one ------------------------
			if mixed:
				final.append(make('=', LOOP_DELTA_VAR, zero))
			else:
				final.append(make('=', LOOP_DELTA_VAR + brind.INT_SUFFIX, zero))
				final.append(make('=', LOOP_DELTA_VAR + brind.FRAC_SUFFIX, zero))
		else:
			# ----------------------- Non-zeno basic -----------------------
			final.append(make('!', LOOPEL_VAR))
			for clk in self.clockmaxdict:
				bound = expressions.Number(self.clockmaxdict[clk])
				if mixed:
					ok_def = ('|', ('=', clk, zero), ('>', clk, bound))
				else:
					int_part = clk + brind.INT_SUFFIX
					frac_part = clk + brind.FRAC_SUFFIX
					is_zero = ('&', ('=', int_part, zero), ('=', frac_part, zero))
					above_max = ('|',
						('>', int_part, bound),
						('&', ('=', int_part, bound), ('>', frac_part, zero)))
					ok_def = ('|', is_zero, above_max)
				final.append(make('<->', OK_VAR + clk, ok_def))
		self.assertion(final, index)
	
	def add_ltl_variables(self, index):
		"""Create the subformula variables for step 'index'.

		Must be called for consecutive indices starting at 0: the current
		self.yi.vnames is recorded as entry 'index' of self.propexpr_low,
		then one boolean variable per entry of self.propexpr_list is added
		through self.yi.add_var.
		"""
		assert len(self.propexpr_low) == index
		self.propexpr_low.append(self.yi.vnames)
		for entry in self.propexpr_list:
			step_name = unrolling.VAR_AT_STEP_FMT % (entry[0], index)
			self.yi.add_var(step_name, expressions.BOOLEAN)
	
	def add_variables(self, index):
		"""Add all variables of step 'index'.

		The subformula variables are created one step ahead (index + 1;
		step 0 is created together with step 1 on the first call) before
		unrolling.BMC.add_variables is invoked. Unless config.ltl_bmc_lasso
		is set, self.assertion is then called for 'index'; with
		config.ltl_bmc_mixed_type it ties every clock to its integer and
		fractional part, otherwise the list of constraints is empty.
		"""
		if index == 0:
			self.add_ltl_variables(0)
		self.add_ltl_variables(index + 1)
		unrolling.BMC.add_variables(self, index)

		if config.ltl_bmc_lasso:
			return

		# Int-frac split: 0 <= frac < 1 and clk = int + frac
		split = []
		if config.ltl_bmc_mixed_type:
			make = expressions.AstExpression
			zero = expressions.Number(0)
			one = expressions.Number(1)
			for clk in self.clockmaxdict:
				frac = clk + brind.FRAC_SUFFIX
				split += [
					make('>=', frac, zero),
					make('<', frac, one),
					make('=', clk, ('+', clk + brind.INT_SUFFIX, frac)),
				]
		self.assertion(split, index)
	
	def loop_end_constraints(self, index):
		"""Fix the value of every subformula variable at step index + 1.

		Approximation variables become a constant: false for 'F', 'U' and
		'GF', true for 'G', 'V' and 'FG'. Every other variable equals the
		disjunction, over the steps s < index, of (LOOP_VAR at s) & (source
		at s), where the source is the approximation variable for MITL
		subformulas and the variable itself otherwise. All equalities are
		built first and then asserted through self.yi.
		"""
		const_false = expressions.BooleanConstant(False)
		const_true = expressions.BooleanConstant(True)
		fmt = unrolling.VAR_AT_STEP_FMT
		make = expressions.AstExpression

		pending = []
		for pos, (name, expr, approx) in enumerate(self.propexpr_list):
			is_mitl = isinstance(expr, make) and expr[0] in ltl.MITLOPS
			# A MITL entry is directly followed by its approximation entry.
			assert (not is_mitl) or approx \
						or self.propexpr_list[pos+1][0] == name + APPROX_SUFFIX
			if approx:
				op = expr[0]
				if op in ('F', 'U', 'GF'):
					rhs = const_false
				elif op in ('G', 'V', 'FG'):
					rhs = const_true
				else:
					assert False
			else:
				source = self.propexpr_list[pos+1][0] if is_mitl else name
				disjuncts = [
					make('&', fmt % (LOOP_VAR, sind), fmt % (source, sind))
					for sind in xrange(index)]
				rhs = expressions.reduce_to_ast_commutative('|', disjuncts, const_false)
			pending.append(make('=', fmt % (name, index + 1), rhs))

		for formula in pending:
			self.yi.assertion(self.yi.encode(formula))
	
	def step(self, index, initial_act = None):
		"""Assert the constraints of step 'index'.

		After unrolling.BMC.step this asserts, in order: the update of
		SMALLER_EXISTS (at step 0 for index 0, otherwise at index - 1), the
		definitions of all subformula variables together with the
		at-most-one-loop constraint at 'index' (plus the top-level negated
		property when index is 0), and finally the constraints of
		assert_loop_var_permanent.
		"""
		unrolling.BMC.step(self, index, initial_act)

		# Not used below; the lookup is retained so an 'index' without LTL
		# variables still raises IndexError exactly as before.
		ltl_low = self.propexpr_low[index]

		make = expressions.AstExpression
		primed = preproc.PRIME_SUFFIX

		# Property: one defining equality per subformula variable
		constraints = []
		for name, expr, approx in self.propexpr_list:
			if not (isinstance(expr, make) and expr[0] in ltl.MITLOPS):
				constraints.append(make('=', name, expr))
				continue
			op = expr[0]
			succ = name + primed
			if op == 'F' or (op == 'GF' and approx):
				constraints.append(make('=', name, ('|', expr[1], succ)))
			elif op == 'G' or (op == 'FG' and approx):
				constraints.append(make('=', name, ('&', expr[1], succ)))
			elif op == 'U':
				constraints.append(make('=', name, ('|', expr[2], ('&', expr[1], succ))))
			elif op == 'V':
				constraints.append(make('=', name, ('&', expr[2], ('|', expr[1], succ))))
			elif op in ('GF', 'FG'):
				# Only reached for the non-approximation variable.
				assert not approx
				constraints.append(make('=', name, succ))
			else:
				assert False

		if index == 0:
			constraints.append(self.nproperty)
			# SmallerExists: no loop variable can precede step 0
			initial = make('=', SMALLER_EXISTS, expressions.BooleanConstant(False))
			self.assertion([initial], 0)
		else:
			# SmallerExists: carried over from the previous step
			carried = make('=', SMALLER_EXISTS + primed, ('|', SMALLER_EXISTS, LOOP_VAR))
			self.assertion([carried], index - 1)

		# AtmostOne
		constraints.append(make('!', ('&', SMALLER_EXISTS, LOOP_VAR)))

		self.assertion(constraints, index)

		self.assert_loop_var_permanent(index)
	
	def check(self, index):
		"""Check the unrolling closed at step 'index'.

		Calls self.yi.push(), adds the loop-end and loop-variable
		constraints for 'index', and queries self.yi.check().

		Returns (False, counterexample) when self.yi.check() is true, with
		the counterexample built from self.yi.get_model(); returns
		(True, None) otherwise.

		self.yi.pop() is called on every exit path, including when adding
		the constraints raises, so each push is matched by a pop.
		"""
		self.yi.push()
		try:
			# Inside the try block: a failure while generating the constraints
			# must not leave the pushed context behind.
			self.loop_end_constraints(index)
			self.assert_loop_var_constraints(index)
			#TODO: closing
			if self.yi.check():
				return False, self.get_counter_example(self.yi.get_model())
			return True, None
		finally:
			self.yi.pop()
		
	
	def get_counter_example(self, mdl):
		"""Build a states.CounterExample from the model 'mdl'.

		One (position, states.State) pair is appended per entry of
		self.var_low, in order.
		"""
		trace = states.CounterExample()
		for position, low in enumerate(self.var_low):
			trace.append((position, states.State(mdl, low, self)))
		return trace

