import time
import boto3
import logging
import decimal
import json

from botocore.exceptions import ClientError
from token_bucket import StorageBase

# Logger named after this module (via __name__).
logger = logging.getLogger(name=__name__)


class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``decimal.Decimal`` values as numbers.

    Integral finite decimals become ``int`` (so ``Decimal('3')`` encodes as
    ``3``, not ``3.0``). Every other decimal becomes ``float``. That covers
    fractional values of either sign, plus NaN and +/-Infinity.
    """

    def default(self, o):
        if isinstance(o, decimal.Decimal):
            # Compare against the integral value instead of using ``o % 1``:
            # Decimal's remainder keeps the dividend's sign, so ``% 1 > 0``
            # misclassified negative fractions (e.g. -1.5) as integers, and
            # ``%`` raises InvalidOperation for non-finite or huge values.
            if o.is_finite() and o == o.to_integral_value():
                return int(o)
            return float(o)
        return super().default(o)


class DynamoDBStorage(StorageBase):
    """Token-bucket storage backend that keeps bucket state in DynamoDB.

    Each bucket is one item keyed by ``ID`` with two numeric attributes:

    - ``tokens``: the tokens currently available.
    - ``last_replenished``: the Unix time, in seconds, of the last refill.

    Writes that depend on an earlier read are guarded by DynamoDB condition
    expressions. This stops several processes sharing the table from
    double-refilling or overdrawing a bucket.
    """

    # The historical name is kept for compatibility, but this must be the
    # *resource* interface. ``.Table(...)`` (used in ``__init__``) exists only
    # on the service resource, not on the low-level ``boto3.client``.
    ddb_client = boto3.resource('dynamodb')

    def __init__(self, table_name=None):
        """Bind this storage to the DynamoDB table named *table_name*.

        Raises:
            ValueError: if *table_name* is not given.
        """
        if not table_name:
            raise ValueError("table_name is required")
        self.ddb_table = self.ddb_client.Table(table_name)

    @staticmethod
    def _to_decimal(value) -> decimal.Decimal:
        """Return *value* as a ``Decimal`` usable as a DynamoDB number.

        boto3 rejects Python floats. Ints and floats are therefore converted
        through ``str``, which keeps a float's short repr rather than its exact
        binary expansion. Decimals are returned unchanged.
        """
        if isinstance(value, decimal.Decimal):
            return value
        return decimal.Decimal(str(value))

    @staticmethod
    def _is_condition_failure(error) -> bool:
        """Return True if the ClientError *error* is a failed ConditionExpression."""
        code = error.response['Error']['Code']
        return code == 'ConditionalCheckFailedException'

    def _get_token_from_ddb(self, key, initial_tokens=0) -> dict:
        """Return the bucket item for *key*, creating it if it does not exist.

        A newly created bucket holds *initial_tokens* tokens, and its
        ``last_replenished`` is set to the current time. Creation only succeeds
        if the item is still absent. If another writer creates it first, that
        writer's item is read back and returned instead.

        Raises:
            ClientError: any DynamoDB error is logged and re-raised.
        """
        logger.debug("get_current_count")
        try:
            response = self.ddb_table.get_item(Key={'ID': key})
            logger.debug(response)
            if 'Item' in response:
                return response['Item']

            init_item = {
                'ID': key,
                'tokens': self._to_decimal(initial_tokens),
                'last_replenished': self._to_decimal(time.time()),
            }
            try:
                self.ddb_table.put_item(
                    Item=init_item,
                    ConditionExpression='attribute_not_exists(#pk)',
                    ExpressionAttributeNames={'#pk': 'ID'})
            except ClientError as e:
                if not self._is_condition_failure(e):
                    raise
                # We lost the creation race, so adopt the item the winner wrote.
                response = self.ddb_table.get_item(
                    Key={'ID': key}, ConsistentRead=True)
                return response['Item']
            return init_item
        except ClientError as e:
            logger.error(e.response['Error']['Message'])
            raise

    def get_token_count(self, key) -> float:
        """Return the number of tokens currently in the bucket for *key*.

        If the bucket does not exist yet, it is created with zero tokens.
        """
        return float(self._get_token_from_ddb(key)['tokens'])

    def replenish(self, key, _, capacity) -> bool:
        """Credit the bucket for *key* for the time since its last refill.

        Args:
            key: bucket identifier (the item's ``ID``).
            _: refill rate in tokens per second. This is the ``rate`` argument
                of ``StorageBase.replenish``; the parameter name is kept for
                backward compatibility.
            capacity: maximum number of tokens the bucket may hold. A bucket
                that does not exist yet is created full.

        Returns:
            True if a refill was written. False if no time has elapsed since
            the last refill, or if a concurrent refill got there first.

        Raises:
            ClientError: on any DynamoDB error other than a lost race.
        """
        rate = self._to_decimal(_)
        item = self._get_token_from_ddb(key, initial_tokens=capacity)
        tokens = self._to_decimal(item['tokens'])
        last_replenished = self._to_decimal(item['last_replenished'])
        now = self._to_decimal(time.time())

        elapsed = now - last_replenished
        if elapsed <= 0:
            # Same instant, or clock skew between processes sharing the table.
            return False

        # Never add more than the free room seen in the read, and apply the
        # refill additively. Consumes that land between our read and write
        # only lower ``tokens``, so the stored result still cannot exceed
        # ``capacity``, and no concurrent deduction is overwritten.
        room = max(decimal.Decimal(0), self._to_decimal(capacity) - tokens)
        added = min(room, rate * elapsed)
        try:
            self.ddb_table.update_item(
                Key={'ID': key},
                UpdateExpression=(
                    "SET tokens = tokens + :added, "
                    "last_replenished = :current_time"),
                # Optimistic lock: this fails if another process refilled since
                # our read, so the same interval is never credited twice.
                ConditionExpression="last_replenished = :last_replenished",
                ExpressionAttributeValues={
                    ':added': added,
                    ':current_time': now,
                    ':last_replenished': last_replenished,
                })
        except ClientError as e:
            if self._is_condition_failure(e):
                return False
            raise
        return True

    def consume(self, key, num_tokens) -> bool:
        """Atomically take *num_tokens* tokens from the bucket for *key*.

        Returns:
            True if the tokens were taken. False if the bucket holds fewer
            than *num_tokens* tokens or does not exist; in that case nothing
            is written.

        Raises:
            ClientError: on any DynamoDB error other than an insufficient
                balance.
        """
        try:
            response = self.ddb_table.update_item(
                Key={'ID': key},
                UpdateExpression="SET tokens = tokens - :num_tokens",
                # Check and decrement in a single request, which avoids a
                # read-modify-write race between concurrent consumers.
                ConditionExpression="tokens >= :num_tokens",
                ExpressionAttributeValues={
                    ':num_tokens': self._to_decimal(num_tokens),
                },
                ReturnValues="UPDATED_NEW")
        except ClientError as e:
            if self._is_condition_failure(e):
                return False
            raise
        logger.debug("consume response: %s", response)
        return True
