import math
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def y(x1, x2, a):
    """Return Q(x1, x2) = A0 + A1*x1 + A2*x2 + A3*x1*x2 + A4*x1**2 + A5*x2**2.

    ``a`` holds the coefficients A0..A5 in order.  ``x1`` and ``x2`` may be
    scalars or NumPy arrays (the surface plot passes meshgrid arrays), in
    which case the polynomial is evaluated element-wise.
    """
    c0, c1, c2, c3, c4, c5 = (a[k] for k in range(6))
    # First-order part, then the cross and square terms, summed in the
    # same left-to-right order as the textbook formula.
    affine_part = c0 + c1*x1 + c2*x2
    return affine_part + c3*x1*x2 + c4*x1**2 + c5*x2**2

def plot_surface_and_trajectory(a, steps, *, x1_lim=(-10, 10), x2_lim=(-10, 10),
                                z_lim=None, resolution=100):
    """Draw the surface Q(x1, x2) and the search trajectory on one 3-D plot.

    Args:
        a: Coefficients A0..A5 forwarded to ``y``.
        steps: Sequence of rows whose first three entries are
            ``[x1, x2, Q]`` (extra columns are ignored).  May be empty, in
            which case only the surface is drawn.
        x1_lim, x2_lim: Plotted x1/x2 ranges.  They are widened when needed
            so the whole trajectory stays inside the axes box.
        z_lim: Vertical range.  ``None`` (default) fits it to the surface and
            trajectory values.  A fixed range such as (-100, 100) lets the
            surface spill outside the box, because mplot3d does not clip
            artists to the axis limits by default.
        resolution: Number of grid samples per axis for the surface.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    trajectory = np.asarray(steps, dtype=float)
    has_path = (trajectory.ndim == 2 and trajectory.shape[0] > 0
                and trajectory.shape[1] >= 3)

    if has_path:
        # Widen the window so no trajectory point lies outside the box.
        x1_lim = (min(x1_lim[0], trajectory[:, 0].min()),
                  max(x1_lim[1], trajectory[:, 0].max()))
        x2_lim = (min(x2_lim[0], trajectory[:, 1].min()),
                  max(x2_lim[1], trajectory[:, 1].max()))
        ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2],
                marker='o', color='red', linestyle='-')

    # Sample the target function on a regular grid for the surface.
    x1_grid, x2_grid = np.meshgrid(np.linspace(x1_lim[0], x1_lim[1], resolution),
                                   np.linspace(x2_lim[0], x2_lim[1], resolution))
    q_grid = y(x1_grid, x2_grid, a)
    ax.plot_surface(x1_grid, x2_grid, q_grid, cmap='viridis', alpha=0.5)

    if z_lim is None:
        z_values = q_grid.ravel()
        if has_path:
            z_values = np.concatenate((z_values, trajectory[:, 2]))
        z_lo, z_hi = float(np.min(z_values)), float(np.max(z_values))
        if z_lo == z_hi:
            # A flat surface would give a zero-height axis; pad it.
            z_lo, z_hi = z_lo - 1.0, z_hi + 1.0
        z_lim = (z_lo, z_hi)

    ax.set_xlim(x1_lim)
    ax.set_ylim(x2_lim)
    ax.set_zlim(z_lim)

    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    ax.set_zlabel('Q(X1, X2)')
    ax.set_title('Gradient Descent Trajectory and Target Function')

    plt.show()

def gradient_descent(a, control, eps, delta_x, x10, x20, h, *, max_iter=100_000):
    """Search for an extremum of Q = ``y(x1, x2, a)`` by gradient steps.

    Each iteration estimates the partial derivatives with forward differences
    of size ``delta_x`` and moves the working point by
    ``control * h * gradient``.  Forward differences carry an O(delta_x)
    truncation error; for the quadratic ``y`` the estimate is biased by
    ``A4*delta_x`` and ``A5*delta_x`` on the two components.  The converged
    point can therefore differ from the analytical optimum, which is printed
    for comparison.

    The search stops when
      * the gradient magnitude Mod(G) is ``<= eps`` (converged; the surface
        and trajectory are then plotted),
      * Q did not change after a move (the original "special case"),
      * the gradient becomes non-finite (diverged), or
      * ``max_iter`` iterations have been recorded.

    Args:
        a: Coefficients A0..A5 of Q(x1, x2).
        control: +1 to move along the gradient (maximum), -1 against it (minimum).
        eps: Convergence threshold on the gradient magnitude.
        delta_x: Displacement used for the finite-difference derivatives.
        x10, x20: Starting coordinates.
        h: Coefficient linking the step length to the gradient magnitude.
        max_iter: Safety cap on the number of iterations, so an oscillating
            search cannot run forever.  ``None`` disables the cap.

    Returns:
        The recorded steps, one row ``[x1, x2, Q, dQ/dx1, dQ/dx2, h*Mod(G)]``
        per iteration.
    """
    x1, x2 = x10, x20
    q_prev = None  # Q at the previous working point; None before the first move.
    q = math.nan
    mod_g = math.inf
    evaluations = 0  # Number of calls to y() made by the search.
    steps = []
    status = 'max_iter'

    while max_iter is None or len(steps) < max_iter:
        q = y(x1, x2, a)
        evaluations += 1
        # Q did not change after the last move: further steps may loop.
        if q == q_prev:
            status = 'stalled'
            break

        dqx1 = (y(x1 + delta_x, x2, a) - q) / delta_x
        dqx2 = (y(x1, x2 + delta_x, a) - q) / delta_x
        evaluations += 2
        mod_g = math.hypot(dqx1, dqx2)
        if not math.isfinite(mod_g):
            status = 'diverged'
            break

        steps.append([x1, x2, q, dqx1, dqx2, mod_g*h])

        print(f'For working point P: X1={x1:.4f}; X2={x2:.4f}; Q(X1,X2)={q:.4f}; '
              f'dQ/dx1={dqx1:.5f}; dQ/dx2={dqx2:.5f}; '
              f'Magnitude of Q function gradient: Mod(G)={mod_g:.4f}; '
              f'Step size H*Mod(G): {h*mod_g:.5f}.')

        # Stop before moving, so the reported point matches q and mod_g.
        if mod_g <= eps:
            status = 'converged'
            break

        x1 += control * h * dqx1
        x2 += control * h * dqx2
        q_prev = q

    if status == 'converged':
        print('Working point P represents the solution.')
        print(f'X1={x1:.3f}; X2={x2:.3f}; Q(X1,X2)={q:.4f}; '
              f'Magnitude of Q function gradient: Mod(G)={mod_g:.4f};')
        print(f'Cost of search (number of Q function evaluations): {evaluations}')

        # Stationary point of the quadratic: solve grad Q = 0.  The Hessian
        # determinant 4*A4*A5 - A3^2 must be non-zero for a unique solution.
        det = 4*a[4]*a[5] - a[3]**2
        if det == 0:
            print('Exact solution (analytical): no unique stationary point '
                  '(4*A4*A5 - A3^2 = 0).')
        else:
            x1opt = (a[2]*a[3] - 2*a[1]*a[5]) / det
            x2opt = (a[1]*a[3] - 2*a[2]*a[4]) / det
            qopt = y(x1opt, x2opt, a)
            print(f'Exact solution (analytical): X1opt={x1opt:.4f} X2opt={x2opt:.4f} Qopt={qopt:.4f}')

        plot_surface_and_trajectory(a, steps)
    elif status == 'stalled':
        print('Special case! The value of the target function at the current point '
              'coincides with its value at the previous point. If the gradient magnitudes '
              'also coincide, looping of the process is possible. '
              'It is advisable to repeat the search using another starting point.')
    elif status == 'diverged':
        print('The search diverged (the gradient became non-finite). '
              'Try a smaller coefficient h, check the sign of control, '
              'or use another starting point.')
    else:
        print(f'Iteration limit ({max_iter}) reached before Mod(G) <= eps; '
              f'last gradient magnitude Mod(G)={mod_g:.4f}. '
              'Try another coefficient h or starting point.')

    return steps

# Example parameters for a demo run (kept at module level so they stay importable).
a = [1, 2, 3, 4, 5, 6]  # Coefficients A0-A5 of Q(X1, X2)
control = -1  # +1 searches for a maximum, -1 for a minimum
eps = 0.01  # Convergence threshold on the gradient magnitude Mod(G)
delta_x = 0.01  # Displacement used for the finite-difference partial derivatives
x10, x20 = 0, 1  # Initial coordinates
h = 0.01  # Coefficient linking the step length to the gradient magnitude

if __name__ == "__main__":
    # Run the demo (console output plus a plot window) only when executed as
    # a script, not when the module is imported.
    gradient_descent(a, control, eps, delta_x, x10, x20, h)
