"""
Catch game.

"""

import numpy
import logging

class Ball(object):
    """
    A single falling ball that the agent has to catch.

    The ball starts on the top row at a random column and, once released,
    moves one row down per step while drifting horizontally by ``dx``
    columns. It bounces off the left and right edges of the screen.
    """

    def __init__(self,
                 maxAngle,
                 numRows,
                 numCols,
                 prevBall,
                 minGapBall,
                 maxGapBall,
                 negRewards,
                 rng):
        """
        Ball object.

        Keeps track of the position (and dynamics) of a given ball that
        the agent has to catch.

        maxAngle   -- upper bound (in degrees) for the falling angle
        numRows    -- number of rows of the screen
        numCols    -- number of columns of the screen
        prevBall   -- start age of the previous ball, or -1 for the first one
        minGapBall -- minimum number of steps after the previous ball
        maxGapBall -- bound used to draw the random extra gap
        negRewards -- if true, a missed ball yields -1 instead of 0
        rng        -- random generator providing ``rand()``
        """
        self.rng = rng

        # The draws happen in the same order as before (direction, angle,
        # column, gap) so that a seeded generator is consumed identically.
        orientation = 1.0 if self.rng.rand() >= .5 else -1.0
        angle       = numpy.floor(self.rng.rand()*maxAngle)
        # The direction draw decides whether the ball drifts left or right.
        self.dx     = orientation * numpy.tan(numpy.pi*angle/180)
        self.dy     = numpy.float32(1)

        self.x = numpy.int32(rng.rand()*numCols)
        self.y = numpy.int32(0)

        # Sub-cell position; x / y are the rounded grid coordinates.
        self.xf = numpy.float32(self.x)
        self.yf = numpy.float32(self.y)

        self.numRows    = numRows
        self.numCols    = numCols
        self.negRewards = negRewards

        # Set once the ball has been scored, so it is rewarded only once.
        self._cached = False
        if prevBall == -1:
            self.startAge = 0
        else:
            offset        = numpy.int32(self.rng.rand()*(maxGapBall-minGapBall))
            self.startAge = prevBall + offset + minGapBall

    def _bounce(self):
        """
        Fold ``xf`` back into [0, numCols-1], flipping ``dx`` on reflection.
        """
        maxX = self.numCols - 1
        if maxX <= 0:
            # Single-column screen: there is nowhere to drift to.
            self.xf = numpy.float32(0)
            return
        # Mirror the position around the walls; this also copes with a
        # horizontal speed larger than the screen width.
        period = 2 * maxX
        folded = self.xf % period
        if folded > maxX:
            folded  = period - folded
            self.dx = -self.dx
        self.xf = folded

    def step(self, age, paddle, paddleW):
        """
        Increment the position of the ball

        The ball only moves once ``age`` is past its ``startAge``.

        Returns a pair ``(active, reward)``: ``active`` is False once the
        ball has reached the bottom row, ``reward`` is the value computed
        by ``catched``.
        """
        if age > self.startAge:
            self.xf += self.dx
            self.yf += self.dy
            self._bounce()

            # Grid coordinates must be integers to be usable as indices.
            self.x = numpy.int32(numpy.round(self.xf))
            self.y = numpy.int32(numpy.round(self.yf))

        return not self.bottom(), self.catched(paddle, paddleW)

    def bottom(self):
        """
        Did the ball reach the bottom of the screen?

        Stays true after the ball has moved past the last row, so a ball
        that already landed is never reported as active again.
        """
        return self.y >= self.numRows-1

    def catched(self, paddle, paddleW):
        """
        Did the agent catch the ball or not (decide reward)

        Returns 1 if the ball landed on the paddle (columns ``paddle`` to
        ``paddle + paddleW - 1``), -1 if it was missed and negative rewards
        are enabled, and 0 otherwise. A ball is scored only once.
        """
        if self._cached or not self.bottom():
            return 0

        self._cached = True
        if paddle <= self.x < paddle + paddleW:
            return 1
        return -1 if self.negRewards else 0

    def draw(self, memBuffer):
        """
        Draw the ball in the current screen ...

        Nothing is drawn once the ball has dropped below the last row.
        """
        if self.y < self.numRows:
            memBuffer[self.y, self.x] = 1

class Catch(object):
    """
    Catch game

    The agent moves a paddle along the bottom row of the screen and has to
    catch the balls that fall from the top row.
    """

    def __init__(self,
                 numRows     = 2,
                 numCols     = 2,
                 maxAngle    = 0,
                 paddleW     = 1,
                 negRewards  = True,
                 numBalls    = 1,
                 minGapBalls = 5,
                 maxGapBalls = 12,
                 noise       = 0,
                 seed        = 123):
        """
        Screen size is numRows * numCols

        Balls fall at an angle between 0 and maxAngle (in degrees). Set to 0
        for simple version of the game.

        The paddle has a width defined by paddleW

        negRewards determine if you get a -1 if you did not catch the ball

        numBalls shows how many balls per episode

        minGapBalls / maxGapBalls bound the random number of steps between
        the release of two consecutive balls

        noise is the probability with which each pixel of the observation
        is independently switched on (0 disables the noise)

        seed initialises the random generator of the game
        """
        self.rng         = numpy.random.RandomState(seed)
        self.numRows     = numRows
        self.numCols     = numCols
        self.maxAngle    = maxAngle
        self.paddleW     = paddleW
        self.seed        = seed
        self.negRewards  = negRewards
        self.age         = 0
        self.noise       = noise
        self.actions     = [-1,0,1]
        self.numBalls    = numBalls
        self.minGapBalls = minGapBalls
        self.maxGapBalls = maxGapBalls
        # Defined up front so that step() works on a fresh game as well.
        self.terminal    = 0
        # Centres the paddle and creates the balls of the first episode.
        self._resetState()

    def reset(self):
        """
        Start a new episode and return its first observation.
        """
        self._resetState()
        self.terminal = 0
        self.age = 0
        return self._getObservation()

    def step(self, action):
        """
        Advance the game by one time step.

        action selects the paddle move via ``self.actions`` (left, stay,
        right); it is taken modulo the number of actions.

        Returns ``(obs, reward, terminal)`` where ``terminal`` is 1 once no
        ball is active any more. Raises if the episode is already over.
        """
        if self.terminal == 1:
            raise RuntimeError("You need to call reset!")
        reward  = 0
        self.age += 1
        # move the agent based on the action
        newPos = self.paddle + self.actions[action % len(self.actions)]
        # bound the agent position
        maxPos = self.numCols - self.paddleW
        self.paddle = min(max(0, newPos), maxPos)

        active_balls = 0
        for ball in self.balls:
            active, loc_reward = ball.step(self.age, self.paddle,
                                           self.paddleW)
            reward += loc_reward
            if active:
                active_balls += 1

        obs = self._getObservation()
        self.terminal = 0 if active_balls > 0 else 1
        return obs, reward, self.terminal

    def _getObservation(self):
        """
        Render the current screen as a float32 array (numRows, numCols).
        """
        screen = numpy.zeros((self.numRows, self.numCols), dtype='float32')
        if self.noise > 0:
            speckles = self.rng.binomial(n=1, p=self.noise,
                                         size=(self.numRows,
                                               self.numCols))
            # Cast so that the observation keeps a single dtype whether or
            # not noise is enabled.
            screen = screen + speckles.astype('float32')
        # Draw balls
        for ball in self.balls:
            ball.draw(screen)
        # Draw paddle
        for dx in range(self.paddleW):
            screen[self.numRows-1, self.paddle + dx] = 1
        return screen

    def _resetState(self):
        """
        Centre the paddle and create a fresh set of balls.
        """
        self.paddle = self.numCols // 2 - self.paddleW//2
        self.balls  = []
        # Each ball is scheduled relative to the start age of the previous
        # one; -1 marks the first ball.
        prevBall = -1
        for _ in range(self.numBalls):
            self.balls.append(Ball(self.maxAngle,
                                   self.numRows,
                                   self.numCols,
                                   prevBall,
                                   self.minGapBalls,
                                   self.maxGapBalls,
                                   self.negRewards,
                                   self.rng))
            prevBall = self.balls[-1].startAge


