import pandas as pd
import numpy as np
from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy.coordinates import EarthLocation, GCRS, ITRS, SkyCoord, CartesianRepresentation, get_sun
from astropy.utils.iers import conf as iers_conf
from scipy.integrate import solve_ivp
from scipy.spatial import KDTree
from poliastro.bodies import Earth
from poliastro.twobody import Orbit

"""

This is Directed Space's original open-source GNC simulation. To those readers who are technically inclined,
feel free to make any modifications and see what numbers come out.

To run:
- Install the various pip packages (we used poliastro/astropy to simplify OrbMech calcs).
  We suggest you use a venv, these packages are not standard FWIW.

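- Make sure that solar_sites.csv is in a solar_data/ directory under the directory you run
  this file from, and that it has "Latitude", "Longitude", "Project Name", and
  "Capacity (MW)" columns at the very least.

- Create a generated_data/ directory in the same place; all output .csvs are written there.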

- Run the commands in the following order:
    > python simulation.py // Creates satellite .csvs, can take a while.
    > python plot.py // Renders the plot in the browser.

- To get a summary of various things, use 
    > python analyze_farm_power.py

"""

# Allow astropy to download IERS Earth-orientation data when a transform needs it.
iers_conf.auto_download = True

# --- Constants ---
N_SATS          = 1                         # number of satellites, spaced evenly in true anomaly
EARTH_RADIUS_KM = 6371                      # Earth radius (km) added to the altitude to get the semi-major axis
epoch           = Time("2027-01-01 00:00:00", scale="utc")  # simulation start time (UTC)
STEP_SEC        = 60                        # integration step (s): the run advances in 1-minute steps
N_STEPS         = 24 * 60                   # 24 hours @ 1 min resolution
LOOKAHEAD       = 1                         # steps by which hand-off decisions trail the simulation (1 × 60 s)
WRITE_INTERVAL  = 60                        # 60 steps × 60 s = 3600 s: write a positions CSV every hour
# NOTE(review): A_MIRROR_M2 does not appear to be referenced in this file;
# compute_reflection_irradiance carries its own default of 2500.0 — confirm and unify.
A_MIRROR_M2     = 2500                      # mirror area (m²)
ALPHA_MAX       = np.deg2rad(0.2)           # 0.2°/s² → rad/s² (used as the slew acceleration in the hand-off time)
RE              = 0.96                      # Reflectivity
AINT            = 0.77                      # Atmospheric Interference


# -------------- Eclipse Checking Fn -----------------------
def is_in_earth_shadow_debug(sat_pos_km, sun_pos_km, threshold_angle_deg=80.0, sat_id=None, time=None, shadow_log=None):
    """
    Classify a satellite as shadowed from the angle between its Earth and Sun directions.

    The angle is measured at the satellite between the unit vector towards the
    Earth's centre and the unit vector towards the Sun. A small angle means the
    Sun sits behind the Earth as seen from the satellite.

    Parameters:
    - sat_pos_km: satellite position [x, y, z]; any 3-element sequence or array
      (integers are accepted). Must be in the same frame as sun_pos_km.
    - sun_pos_km: sun position [x, y, z], same frame and units as sat_pos_km.
      The caller's object is never modified.
    - threshold_angle_deg: angles strictly below this value count as "in shadow";
      angles at or above it count as sunlit.
    - sat_id: optional satellite ID, only used for logging.
    - time: optional time object exposing an ``iso`` attribute, only used for logging.
    - shadow_log: optional list; when it, sat_id and time are all given, one dict
      with keys "time", "satellite_id", "sun_angle_deg", "in_shadow" is appended.

    Returns:
    - True if satellite is in Earth's shadow (angle < threshold_angle_deg)
    """
    # Force float copies: with integer input the original in-place normalisation
    # raised a casting error, and a copy keeps the z-flip below off the caller's data.
    sat_vec = np.array(sat_pos_km, dtype=float)
    sun_vec = np.array(sun_pos_km, dtype=float)

    # NOTE(review): the Sun's z component is mirrored before the angle is taken.
    # Nothing in this file explains why; it is preserved as-is — confirm it is intended.
    sun_vec[2] = -sun_vec[2]

    # Unit vectors from the satellite towards the Earth's centre and towards the Sun.
    u_earth = -sat_vec / np.linalg.norm(sat_vec)
    to_sun = sun_vec - sat_vec
    u_sun = to_sun / np.linalg.norm(to_sun)

    # Clip guards arccos against round-off pushing the dot product outside [-1, 1].
    angle_rad = np.arccos(np.clip(np.dot(u_earth, u_sun), -1.0, 1.0))
    angle_deg = np.degrees(angle_rad)
    in_shadow = angle_deg < threshold_angle_deg

    if shadow_log is not None and sat_id is not None and time is not None:
        shadow_log.append({
            "time": time.iso,
            "satellite_id": sat_id,
            "sun_angle_deg": angle_deg,
            "in_shadow": in_shadow
        })

    return in_shadow


# -------------- Solar Farm Setup ----------------------------
# Keep only the 500 largest sites by listed capacity.
solar_df = pd.read_csv("solar_data/solar_sites.csv")
solar_df = solar_df.sort_values("Capacity (MW)", ascending=False).head(500)

farm_xyz          = []          # ITRS xyz per site (km), in solar_df row order, fed to the KD-tree
farm_itrs_vec     = {}          # project name -> ITRS xyz (km) evaluated at epoch
farm_unit_normals = {}          # project name -> ITRS xyz scaled to unit length
farm_locations    = {}          # project name -> EarthLocation at zero height

# One pass over the table fills every per-site container.
for _, site in solar_df.iterrows():
    site_name = site["Project Name"]
    site_loc = EarthLocation(
        lon=site["Longitude"] * u.deg,
        lat=site["Latitude"] * u.deg,
        height=0 * u.m
    )
    site_xyz = site_loc.get_itrs(obstime=epoch).cartesian.xyz.to_value(u.km)

    farm_locations[site_name]    = site_loc
    farm_xyz.append(site_xyz)
    farm_itrs_vec[site_name]     = site_xyz
    farm_unit_normals[site_name] = site_xyz / np.linalg.norm(site_xyz)

# Spatial index over site positions, plus a name -> capacity lookup.
solar_tree    = KDTree(farm_xyz)
farm_capacity = solar_df.set_index("Project Name")["Capacity (MW)"].to_dict()




# ----------------------- Irradiance utility --------------------------
def compute_reflection_irradiance(sat_pos, sun_pos, target_pos, target_normal,
                                  sun_half_angle_deg=0.25, solar_constant=1361.0,
                                  A_mirror_m2=2500.0, return_all=False):
    """
    Estimate the irradiance a flat mirror on the satellite delivers to a ground target.

    The mirror normal is taken as the bisector of the satellite->Sun and
    satellite->target directions. All four vectors must be expressed in the same
    frame; positions are in km (the spot radius is multiplied by 1000 before it
    is used as metres). Inputs may be lists or arrays of any numeric dtype and
    are never modified.

    Parameters:
    - sat_pos, sun_pos, target_pos: 3-vectors (km).
    - target_normal: unit 3-vector normal to the target surface.
    - sun_half_angle_deg: half-angle of the solar disc, sets the spot radius.
    - solar_constant: irradiance arriving at the mirror (W/m²).
    - A_mirror_m2: mirror area (m²).
    - return_all: selects the return shape, see below.

    Returns:
    - return_all=True: tuple (I0 * cos_phi, r_spot_km, phi_deg, power_transmitted_W).
      NOTE(review): this value is NOT scaled by RE * AINT, unlike
      "max_irradiance_W_m2" in the dict form — confirm that is intended.
    - return_all=False: dict with the mirror normal, reflected direction, incidence
      angle, spot radius, Euler angles of the mirror normal, transmitted power and
      the RE * AINT-scaled peak irradiance.
    """
    # Float conversion: the original normalised integer arrays in place, which
    # numpy rejects with a casting error, and could not take plain lists at all.
    sat_pos = np.asarray(sat_pos, dtype=float)
    sun_pos = np.asarray(sun_pos, dtype=float)
    target_pos = np.asarray(target_pos, dtype=float)
    target_normal = np.asarray(target_normal, dtype=float)

    to_sun = sun_pos - sat_pos
    u_sun = to_sun / np.linalg.norm(to_sun)

    to_target = target_pos - sat_pos
    d_target = np.linalg.norm(to_target)
    u_target = to_target / d_target

    # Bisector of the two directions.
    mirror_normal = u_sun + u_target
    mirror_normal = mirror_normal / np.linalg.norm(mirror_normal)

    # u_sun mirrored about the plane perpendicular to mirror_normal.
    reflected_dir = u_sun - 2 * np.dot(u_sun, mirror_normal) * mirror_normal
    reflected_dir = reflected_dir / np.linalg.norm(reflected_dir)

    # Negative projections are clipped to zero.
    cos_phi = np.clip(np.dot(reflected_dir, target_normal), 0, 1)
    phi_deg = np.degrees(np.arccos(cos_phi))

    theta_sun = np.radians(sun_half_angle_deg)
    r_spot = d_target * np.sin(theta_sun)  # km
    mirror_angle = np.arccos(np.clip(np.dot(mirror_normal, u_sun), 0, 1))
    power_transmitted = A_mirror_m2 * np.cos(mirror_angle) * solar_constant * cos_phi

    # Peak intensity for the assumed spot profile; the radius is in metres here,
    # so I0 is per m².
    # NOTE(review): cos_phi is already inside power_transmitted and is applied
    # again wherever I0 * cos_phi is returned — confirm this is the intended model.
    spot_coeff = 8 * (1 - 8 / np.pi**2)
    I0 = power_transmitted / (spot_coeff * (r_spot * 1000)**2)

    if return_all:
        # Fast path used by the simulation loop: skip the Euler/ellipse extras.
        return I0 * cos_phi, r_spot, phi_deg, power_transmitted

    # Elongation factor 1 / cos_phi, infinite at grazing incidence.
    a = 1 / cos_phi if cos_phi > 1e-6 else np.inf

    # Pointing of the mirror normal as yaw about z and pitch above the x-y plane;
    # roll is fixed at zero.
    nx, ny, nz = mirror_normal
    roll, pitch, yaw = np.degrees([0.0, np.arctan2(nz, np.hypot(nx, ny)), np.arctan2(ny, nx)])

    return {
        "mirror_normal_x": mirror_normal[0],
        "mirror_normal_y": mirror_normal[1],
        "mirror_normal_z": mirror_normal[2],
        "reflected_dir_x": reflected_dir[0],
        "reflected_dir_y": reflected_dir[1],
        "reflected_dir_z": reflected_dir[2],
        "phi_deg": phi_deg,
        "cos_phi": cos_phi,
        "r_spot_km": r_spot,
        "ellipse_a": a,
        # Key name kept for compatibility; the value is computed per m² (see above).
        "I0_W_km2": I0,
        "euler_roll_deg": roll,
        "euler_pitch_deg": pitch,
        "euler_yaw_deg": yaw,
        "power_transmitted_W": power_transmitted,
        "max_irradiance_W_m2": I0 * cos_phi * RE * AINT
    }


# ----- Fn to compute handoff duration (bang bang method, accurate to 20%) ------
def compute_handoff_duration(sid: int, prev: str, cand: str, dec_t: Time):
    """
    Return the slew time in seconds needed to re-point from farm `prev` to farm `cand`.

    For each farm the mirror normal is the bisector of the satellite->Sun and
    satellite->farm directions in GCRS at `dec_t`; the slew angle theta is the
    angle between the two normals and the result is sqrt(2 * theta / ALPHA_MAX).
    Returns 0.0 when either name is "NULL".

    NOTE(review): the satellite position is whatever is currently stored in the
    module-level sat_states[sid]; it is not propagated to `dec_t` here — confirm
    callers invoke this while that state matches the decision time.
    """
    if "NULL" in (prev, cand):
        return 0.0

    decision_frame = GCRS(obstime=dec_t)

    # Satellite position (km), routed through a GCRS SkyCoord as before.
    sat_xyz = sat_states[sid]
    rep_dec = CartesianRepresentation(sat_xyz[0] * u.km, sat_xyz[1] * u.km, sat_xyz[2] * u.km)
    pos_km = SkyCoord(rep_dec, frame=decision_frame).cartesian.xyz.to_value(u.km)

    # Sun position (km) in the same frame.
    sunp = get_sun(dec_t).transform_to(decision_frame).cartesian.xyz.to_value(u.km)

    # The Sun direction is the same for both farms, so compute it once.
    sat_to_sun = sunp - pos_km
    u_s = sat_to_sun / np.linalg.norm(sat_to_sun)

    def mirror_normal_for(farm_name: str) -> np.ndarray:
        """Unit mirror normal that points the reflection at the named farm at dec_t."""
        farm_xyz_gcrs = farm_locations[farm_name].get_gcrs(obstime=dec_t).cartesian.xyz.to_value(u.km)
        sat_to_farm = farm_xyz_gcrs - pos_km
        u_t = sat_to_farm / np.linalg.norm(sat_to_farm)
        bisector = u_s + u_t
        return bisector / np.linalg.norm(bisector)

    cos_theta = np.dot(mirror_normal_for(prev), mirror_normal_for(cand))
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    return np.sqrt(2 * theta / ALPHA_MAX)

# ------------------------- J2 model for SSO prop ---------------------------
def rk4_step_j2(state, dt):
    """
    Advance a 6-element state [x, y, z, vx, vy, vz] by `dt` with one RK4 step.

    The force model is two-body gravity plus the J2 zonal term, using
    mu = 398600.4418 km^3/s^2 and R = 6378.137 km, so positions are in km,
    velocities in km/s and `dt` in seconds. `state` must be a numpy array;
    a new array is returned and the input is left untouched.
    """
    mu = 398600.4418      # km^3/s^2
    r_ref = 6378.137      # km
    j2_coeff = 1.08263e-3

    def state_derivative(s):
        """Time derivative of the state: (velocity, two-body + J2 acceleration)."""
        pos = s[:3]
        x, y, z = pos

        # J2 perturbation
        radius = np.sqrt(x**2 + y**2 + z**2)
        factor = 1.5 * j2_coeff * mu * r_ref**2 / radius**5
        zx = z / radius
        lateral = 5 * zx**2 - 1
        polar = 5 * zx**2 - 3
        a_j2 = np.array([factor * (x * lateral), factor * (y * lateral), factor * (z * polar)])

        # Point-mass gravity
        a_kepler = -mu * pos / np.linalg.norm(pos)**3

        return np.concatenate((s[3:], a_kepler + a_j2))

    half_dt = 0.5 * dt
    k1 = state_derivative(state)
    k2 = state_derivative(state + half_dt * k1)
    k3 = state_derivative(state + half_dt * k2)
    k4 = state_derivative(state + dt * k3)

    return state + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)

# ----------------------------  Initial orbits --------------------------------
# Circular orbit at a fixed altitude above EARTH_RADIUS_KM.
alt_km = 650
a = (alt_km + EARTH_RADIUS_KM) * u.km
ecc = 0.0 * u.one
inc = 97.986 * u.deg
argp = 0 * u.deg

# Ascending node placed 90 degrees ahead of the Sun's right ascension at epoch.
sun_at_epoch = get_sun(epoch).transform_to(GCRS(obstime=epoch))
sun_ra = sun_at_epoch.ra.to_value(u.deg)
raan = ((sun_ra + 90) % 360) * u.deg

# One [x, y, z, vx, vy, vz] state (km, km/s) per satellite, evenly spaced in true anomaly.
sat_states = []
for i in range(N_SATS):
    nu = (i * 360.0 / N_SATS) * u.deg
    orbit = Orbit.from_classical(Earth, a, ecc, inc, raan, argp, nu, epoch)
    position_km = orbit.r.to_value(u.km)
    velocity_km_s = orbit.v.to_value(u.km / u.s)
    sat_states.append(np.concatenate((position_km, velocity_km_s)))

# ------------------------ Logging + Housekeeping -------------------------------
# Per-step map of satellite id -> farm chosen at that step.
future_farm = [dict() for _ in range(N_STEPS)]

# Farm each satellite is currently committed to, and since when.
farm_assignment = {
    sat_idx: {"farm": "NULL", "start": epoch}
    for sat_idx in range(N_SATS)
}

switch_log = []   # one record per hand-off
shadow_log = []   # one record per satellite per step, filled by the shadow check

# Time-by-farm irradiance table, zero-filled; rows are labelled with ISO timestamps.
farm_names = sorted(farm_locations.keys())
step_labels = [
    (epoch + TimeDelta(step_idx * STEP_SEC * u.s)).iso
    for step_idx in range(N_STEPS)
]
irradiance_matrix = pd.DataFrame(0.0, index=step_labels, columns=farm_names)

# Vector of every step time.
times = epoch + TimeDelta(np.arange(N_STEPS) * STEP_SEC * u.s)

# --------------------------- Run simulation -------------------------------------
# All targeting geometry in this loop is done in the Earth-fixed ITRS frame.
# The previous version passed a GCRS satellite/target together with an ITRS Sun
# and an ITRS farm normal into compute_reflection_irradiance; those frames are
# rotated against each other by the Earth's rotation angle, so the mirror normal
# and incidence angle were computed from inconsistent vectors. Farm positions
# are constant in ITRS, so the cached farm_itrs_vec / farm_unit_normals are used
# directly and no per-farm frame transform is needed.
for step in range(N_STEPS):
    current_time = epoch + TimeDelta(STEP_SEC * step * u.s)

    # Sun position (km) in ITRS at this step.
    sun_gcrs = get_sun(current_time).transform_to(GCRS(obstime=current_time))
    sun_itrs = sun_gcrs.transform_to(ITRS(obstime=current_time))
    sun_pos_itrs = sun_itrs.cartesian.xyz.to_value(u.km)

    results = []

    for i, state in enumerate(sat_states):
        # The propagated state is GCRS; rotate it into ITRS for all geometry below.
        coord = SkyCoord(x=state[0], y=state[1], z=state[2], unit="km",
                         representation_type='cartesian', frame=GCRS(obstime=current_time))
        itrs = coord.transform_to(ITRS(obstime=current_time))
        r_ecef = itrs.cartesian.xyz.to_value(u.km)

        # Straight-line distance to the nearest farm; logged as "distance_km".
        dist, _nearest_idx = solar_tree.query(r_ecef)

        # Evaluate every farm within 2000 km and keep the one with the highest
        # irradiance, together with ITS spot radius / incidence angle / power.
        # (Previously those three were whatever the last candidate produced, and
        # were stale or unbound when no farm was in range.)
        best_I        = -np.inf
        best_f        = "NULL"
        best_r_spot   = np.nan
        best_phi_deg  = np.nan
        best_power_tx = np.nan

        for farm_idx in solar_tree.query_ball_point(r_ecef, r=2000.0):
            name = solar_df.iloc[farm_idx]["Project Name"]
            I, r_spot_km, phi_deg, power_tx = compute_reflection_irradiance(
                r_ecef, sun_pos_itrs, farm_itrs_vec[name], farm_unit_normals[name],
                return_all=True
            )
            if I > best_I:
                best_I, best_f = I, name
                best_r_spot, best_phi_deg, best_power_tx = r_spot_km, phi_deg, power_tx

        # Irradiance too low to be worth pointing at anything.
        if best_I < 1e-2:
            best_f = "NULL"

        # Appends a record to shadow_log.
        # NOTE(review): the returned flag is not used to gate the farm choice
        # or the logged irradiance — confirm that is intended.
        is_in_earth_shadow_debug(
            r_ecef,
            sun_pos_itrs,
            threshold_angle_deg=68,
            sat_id=i,
            time=current_time,
            shadow_log=shadow_log
        )

        future_farm[step][i] = best_f

        # "closest_farm" carries the selected (best) farm, as it effectively did
        # before: the earlier duplicate key holding the nearest farm's name was
        # always overwritten, so it has been removed.
        results.append({
            "satellite_id": i,
            "x_km": r_ecef[0],
            "y_km": r_ecef[1],
            "z_km": r_ecef[2],
            "distance_km": round(dist, 2),
            "near": best_f != "NULL",
            "closest_farm": best_f,
            "max_irradiance_W_m2": best_I,
            "r_spot_km": best_r_spot,
            "phi_deg": best_phi_deg,
            "power_transmitted_W": best_power_tx,
            "sun_x_km": sun_pos_itrs[0],
            "sun_y_km": sun_pos_itrs[1],
            "sun_z_km": sun_pos_itrs[2],
        })

        if best_f != "NULL":
            irradiance_matrix.loc[current_time.iso, best_f] = best_I

        # Advance this satellite to the next step.
        sat_states[i] = rk4_step_j2(state, STEP_SEC)

    # Hand-off decisions trail the simulation by LOOKAHEAD steps (very basic):
    # a switch is committed only when every prediction in the window agrees.
    if step >= LOOKAHEAD:
        dec_t = epoch + TimeDelta((step - LOOKAHEAD) * STEP_SEC * u.s)
        for sid in range(N_SATS):
            preds = [future_farm[step - LOOKAHEAD + k][sid] for k in range(LOOKAHEAD)]
            cand  = preds[0]
            if any(p != cand for p in preds):
                continue

            prev = farm_assignment[sid]["farm"]
            if cand == prev:
                continue

            # NOTE(review): compute_handoff_duration reads sat_states[sid], which
            # has already been advanced past current_time above, while dec_t lies
            # LOOKAHEAD steps in the past — confirm this offset is acceptable.
            T_sec = compute_handoff_duration(sid, prev, cand, dec_t)
            end_t = dec_t + TimeDelta(T_sec * u.s)
            farm_duration = 0.0 if prev == "NULL" \
                else (dec_t - farm_assignment[sid]["start"]).to(u.s).value

            switch_log.append({
                "satellite_id":  sid,
                "from_farm":     prev,
                "from_capacity":   farm_capacity.get(prev, float("nan")),
                "to_farm":       cand,
                "to_capacity":     farm_capacity.get(cand, float("nan")),
                "start_time":    dec_t.iso,
                "end_time":      end_t.iso,
                "duration_s":    T_sec,
                "farm_duration": farm_duration
            })

            # The new assignment starts once the slew has finished.
            farm_assignment[sid] = {"farm": cand, "start": end_t}

    # Snapshot of this step's satellite rows, written once per WRITE_INTERVAL steps.
    if step % WRITE_INTERVAL == 0:
        hour = step // WRITE_INTERVAL
        df = pd.DataFrame(results)
        df.to_csv(f"generated_data/satellite_positions_hour_{hour}.csv", index=False)
        print(f"✅ Wrote satellite_positions_hour_{hour}.csv")


# -------------------------- Flush values -------------------------------------
# Close out assignments still active when the run ends, then write the logs.
t_end = epoch + TimeDelta(N_STEPS * STEP_SEC * u.s)
for sid, entry in farm_assignment.items():
    if entry["farm"] == "NULL":
        continue
    started = entry["start"]
    # Here "duration_s" is the time from the assignment's start to the end of the run.
    closing_record = {
        "satellite_id": sid,
        "from_farm":   entry["farm"],
        "to_farm":     None,
        "start_time":  started.iso,
        "end_time":    t_end.iso,
        "duration_s":  (t_end - started).to(u.s).value
    }
    switch_log.append(closing_record)

switch_df = pd.DataFrame(switch_log)
switch_df.to_csv("generated_data/farm_assignment_durations.csv", index=False)
print("✅ wrote farm_assignment_durations.csv")

# The irradiance table keeps its index: the row labels are the timestamps.
irradiance_matrix.to_csv("generated_data/farm_irradiance_matrix.csv")
print("✅ wrote farm_irradiance_matrix.csv")

shadow_df = pd.DataFrame(shadow_log)
shadow_df.to_csv("generated_data/satellite_shadow_log.csv", index=False)
print("✅ wrote satellite_shadow_log.csv")


   