import random
import math
import mymath
import os


class ModInteger:
    """An integer modulo p.

    The stored ``value`` is always reduced modulo ``p``; for a positive
    modulus Python's ``%`` already yields a result in ``[0, p)``.  Binary
    operations use the left operand's modulus and assume both operands share
    it (this is not checked).
    """

    def __init__(self, p, value):
        self.p = p  # the modulus (assumed positive)
        self.value = value % p

    def multiplicative_identity(self):
        """Return 1 modulo p, the identity expected by exponentiate()."""
        return ModInteger(self.p, 1)

    def __mul__(self, other):
        return ModInteger(self.p, self.value * other.value)

    def __add__(self, other):
        return ModInteger(self.p, self.value + other.value)

    def __sub__(self, other):
        return ModInteger(self.p, self.value - other.value)

    def __neg__(self):
        return ModInteger(self.p, -self.value)

    def __truediv__(self, y):
        """Return ``self * y**-1`` modulo p.

        Raises:
            ZeroDivisionError: if ``y`` has no inverse modulo p, i.e.
                ``gcd(y.value, p) != 1``.
        """
        return self * ModInteger(self.p, self._inverse(y.value, self.p))

    @staticmethod
    def _inverse(a, p):
        """Return the inverse of ``a`` modulo ``p`` (iterative extended Euclid)."""
        old_r, r = a % p, p
        old_s, s = 1, 0
        while r:
            q = old_r // r
            old_r, r = r, old_r - q * r
            old_s, s = s, old_s - q * s
        # old_r is now gcd(a, p); only a gcd of 1 admits an inverse.
        if old_r != 1:
            raise ZeroDivisionError(
                "{} is not invertible modulo {}".format(a, p))
        return old_s % p

    def __str__(self):
        return str(self.value) + " mod " + str(self.p)

    def __repr__(self):
        return "ModInteger({!r}, {!r})".format(self.p, self.value)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on ``other.value``.
        if not isinstance(other, ModInteger):
            return NotImplemented
        return other.value == self.value and other.p == self.p


def exponentiate(g, exponent):
    """Raise a multiplicative group element to a non-negative integer power.

    Uses binary (square-and-multiply) exponentiation, so it needs
    O(log2(exponent)) group multiplications.

    Args:
        g: a group element providing ``multiplicative_identity()`` and ``*``.
        exponent: the integer power. Zero (or any non-positive value) yields
            the identity; negative powers are not supported.

    Returns:
        ``g`` multiplied by itself ``exponent`` times.
    """
    result = g.multiplicative_identity()
    base = g  # holds g ** (2 ** k) while bit k of the exponent is processed
    while exponent > 0:
        if exponent % 2 == 1:
            result = result * base
        exponent >>= 1  # shift right to expose the next bit
        if exponent:
            # Square the running base (not g itself) for the next bit's weight;
            # skipped after the last bit since it would be unused.
            base = base * base
    return result


class PoorRandom:
    """Integer source backed by the ``random`` module; awful for cryptography."""

    def generate(self, n_bits):
        """Return an integer chosen by ``random.randint`` in [0, 2**n_bits - 1]."""
        largest = 2 ** n_bits - 1
        return random.randint(0, largest)


class BetterRandom:
    """Integer source built on ``os.urandom``; suitable for secrets."""

    def generate(self, n_bits):
        """Return an integer derived from ``os.urandom`` in [0, 2**n_bits - 1].

        Enough whole bytes are read to cover ``n_bits``; the surplus high bits
        are discarded by reducing modulo ``2**n_bits``.
        """
        byte_count = int(math.ceil(n_bits / 8))
        as_int = int.from_bytes(os.urandom(byte_count), byteorder='big')
        return as_int % 2 ** n_bits


class DiffieHelmanExchanger:
    """One party of a Diffie-Hellman exchange over a generic group.

    A random secret exponent ``s`` is drawn and ``base ** s`` is published.
    Given the other party's public value ``P``, ``P ** s`` is the shared
    secret; both parties arrive at ``base ** (s_1 * s_2)``.
    """

    def __init__(self, base, random=None, n_bits=512):
        """
        Args:
            base: the agreed group element; must support exponentiate().
            random: object with ``generate(n_bits)``; defaults to a fresh
                BetterRandom. (The name shadows the ``random`` module inside
                this method and is kept for keyword-argument compatibility.)
            n_bits: bit length of the random part of the secret exponent.
        """
        # Build the default per call rather than once at definition time.
        if random is None:
            random = BetterRandom()
        # generate() yields [0, 2**n_bits - 1]; +1 rules out a zero exponent.
        self.__secret = random.generate(n_bits) + 1
        self.__public = exponentiate(base, self.__secret)

    def shared_secret(self, other_public):
        """Return ``other_public`` raised to this party's secret exponent."""
        return exponentiate(other_public, self.__secret)

    @property
    def public(self):
        """The public value ``base ** secret`` to send to the other party."""
        return self.__public


class EllipticCurve:
    """The curve y**2 = x**3 + a*x + b over the integers mod n, and its group.

    This short form is invalid when n is 2 or 3, which is rejected. Presumably
    n is meant to be a prime greater than 3; that is not checked here.
    """

    def __init__(self, a, b, n):
        # NOTE(review): these are ``assert`` checks, so they are skipped when
        # Python runs with -O; kept as-is to preserve the AssertionError type.
        assert n != 2
        assert n != 3
        self.n = n  # must be set before to_ring_element() is used
        self.a = self.to_ring_element(a)
        self.b = self.to_ring_element(b)
        # Discriminant -16 * (4a^3 + 27b^2); zero means the curve is singular.
        self.discriminant = -self.to_ring_element(16) * (
            self.to_ring_element(4) * self.a * self.a * self.a
            + self.to_ring_element(27) * self.b * self.b
        )
        assert self.discriminant != self.to_ring_element(0), 'Curve is not smooth'

    def to_ring_element(self, x):
        """Wrap ``x`` as an integer modulo this curve's n."""
        return ModInteger(self.n, x)

    def contains(self, x, y):
        """Return whether ring elements (x, y) satisfy the curve equation."""
        return y * y == x * x * x + self.a * x + self.b


class EllipticCurvePoint:
    """A point on an elliptic curve, or the curve's point at infinity.

    The group law is written multiplicatively (``P * Q``) and the identity is
    exposed as ``multiplicative_identity()``, matching what exponentiate()
    calls, so ``exponentiate(P, k)`` computes the k-th multiple of P.
    """

    def __init__(self, curve, x=0, y=0, point_at_infinity=False):
        self.curve = curve
        # For the point at infinity the coordinates carry no meaning
        # (they default to 0) and are not validated against the curve.
        self.x = curve.to_ring_element(x)
        self.y = curve.to_ring_element(y)
        self.point_at_infinity = point_at_infinity
        if not point_at_infinity:
            assert curve.contains(self.x, self.y)

    def multiplicative_identity(self):
        """Return the point at infinity, the identity of the group."""
        return EllipticCurvePoint(self.curve, point_at_infinity=True)

    def __eq__(self, other):
        if self.point_at_infinity or other.point_at_infinity:
            # Infinity equals only infinity; its placeholder coordinates must
            # never be compared against a finite point such as (0, 0).
            return self.point_at_infinity and other.point_at_infinity
        return self.x == other.x and self.y == other.y

    def __str__(self):
        return "(" + str(self.x.value) + "," + str(self.y.value) + ")"

    def __mul__(self, other):
        """Add two points with the chord-and-tangent rule.

        This isn't computationally the most efficient method, as it performs
        a modular division for the slope.
        """
        if self.point_at_infinity:
            return other
        if other.point_at_infinity:
            return self

        x_1, y_1, x_2, y_2 = self.x, self.y, other.x, other.y

        if (x_1, y_1) == (x_2, y_2):
            # Doubling a point with y == 0: the tangent is vertical.
            if y_1 == self.curve.to_ring_element(0):
                return EllipticCurvePoint(self.curve, point_at_infinity=True)

            # slope of the tangent line: (3x^2 + a) / (2y)
            m = (self.curve.to_ring_element(3) * x_1 * x_1 + self.curve.a) / \
                (self.curve.to_ring_element(2) * y_1)
        else:
            # Same x, different y: P + (-P), the secant is vertical.
            if x_1 == x_2:
                return EllipticCurvePoint(self.curve, point_at_infinity=True)

            # slope of the secant line
            m = (y_2 - y_1) / (x_2 - x_1)

        x_3 = m * m - x_2 - x_1
        # (x_3, y_3) is the line's third intersection with the curve; the sum
        # is its reflection across the x-axis, hence the negated y below.
        y_3 = m * (x_3 - x_1) + y_1

        return EllipticCurvePoint(self.curve, x=x_3.value, y=-y_3.value)







