"""
###################################################################
#                                                                 #
#      "Gravity for Undergrads" Chapter by Y.V. Yotov             #
#      Python Implementation of Original Stata Code               #
#                                                                 #
#      Author: Ohyun Kwon & Mario Larch                           #
#      Date: November 28, 2025                                    #
#      Python Version: 3.12                                       #
#                                                                 #
#      Script Description:                                        #
#      - Part A: Imports and installs required libraries.         #
#      - Part B: Loads and prepares the data.                     #
#      - Part C: Replicates gravity model estimations.            #
#      - Part D: Generates scatter plot and exports results.      #
#                                                                 #
###################################################################
"""

# ************************************************************************
# A. Import required libraries
# ************************************************************************
# Install packages if not already installed
import subprocess
import sys

def install_if_missing(package):
    """Ensure *package* is importable, installing it with pip if it is not.

    The same string is used as the import name and as the pip requirement,
    so this only works for packages where the two coincide.

    Args:
        package: Module name to import / distribution name to install.

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
    """
    import importlib

    try:
        importlib.import_module(package)
    except ImportError:
        print(f"Installing {package}...")
        # Run pip through the current interpreter so the package is
        # installed into the environment this script is running in.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        # The import system caches directory contents; refresh the caches so
        # the newly installed package is found by imports later in this run.
        importlib.invalidate_caches()
        print(f"{package} installed successfully.")

# Third-party packages this script depends on. Each name is used both for
# the import check and for the pip install inside install_if_missing().
required_packages = [
    "matplotlib",
    "pandas",
    "numpy",
    "pyfixest",
    "scipy",
]

# Make sure every dependency is available before the imports below run
for pkg in required_packages:
    install_if_missing(pkg)

# Import necessary packages
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pyfixest as pf
from scipy import stats

# ************************************************************************
# B. Load and prepare the data
# ************************************************************************

# Read the gravity dataset from the working directory
data = pd.read_csv("Gravity_Undergrads_Data.csv")

# Create the log variables. Non-positive values are masked to NaN before the
# log is taken, so those rows end up with a missing log value.
log_columns = {
    'ln_trade': 'Trade',
    'ln_dist': 'Distance',
    'ln_gdp_exp': 'GDP_Exporter',
    'ln_gdp_imp': 'GDP_Importer',
}
for log_name, level_name in log_columns.items():
    levels = data[level_name]
    data[log_name] = np.log(levels.where(levels > 0))


# ************************************************************************
# C. Estimate gravity models
# ************************************************************************
# C.1 Naive Gravity (OLS without fixed effects)
# Build the 2023 cross-section: keep Year == 2023, convert +/-inf to NaN and
# drop rows with a missing value in any variable of the naive specification
data_2023 = data[data['Year'] == 2023].copy()
data_2023 = data_2023.replace([np.inf, -np.inf], np.nan)
data_2023 = data_2023.dropna(subset=['ln_trade', 'ln_dist', 'RTA', 'EU', 'ln_gdp_exp', 'ln_gdp_imp'])

naive = pf.feols('ln_trade ~ ln_dist + RTA + EU + ln_gdp_exp + ln_gdp_imp',
                 data=data_2023, vcov='iid')
print("\n###\nNaive Gravity:")
# summary() prints the table itself and returns None; wrapping it in print()
# would add a stray "None" line after the table
naive.summary()

# C.2 Structural Gravity (with exporter and importer fixed effects)
# Estimated on the same 2023 sample as C.1, i.e. rows with missing GDP logs
# are excluded here as well even though GDP is not a regressor
structural = pf.feols('ln_trade ~ ln_dist + RTA + EU | Exporter + Importer',
                      data=data_2023, vcov='iid')
print("\n###\nStructural Gravity:")
structural.summary()

# C.3 Panel Gravity (with exporter-year and importer-year fixed effects)
# Panel sample: all years, dropping rows with a missing value in any model
# variable or identifier
data_clean = data.dropna(subset=['ln_trade', 'ln_dist', 'RTA', 'EU', 'Exporter', 'Importer', 'Year']).copy()
# Country-year identifiers built as "<country>_<year>" strings
data_clean['Exporter_Year'] = data_clean['Exporter'].astype(str) + '_' + data_clean['Year'].astype(str)
data_clean['Importer_Year'] = data_clean['Importer'].astype(str) + '_' + data_clean['Year'].astype(str)

panel_hdfe = pf.feols('ln_trade ~ ln_dist + RTA + EU | Exporter_Year + Importer_Year',
                      data=data_clean, vcov='iid')
print("\n###\nPanel Gravity:")
# summary() prints the table itself and returns None; wrapping it in print()
# would add a stray "None" line after the table
panel_hdfe.summary()

# C.4 HDFE Gravity (with pair fixed effects)
# Directional pair identifier "<importer>_<exporter>"; ln_dist is left out of
# this specification
data_clean['Importer_Exporter'] = data_clean['Importer'].astype(str) + '_' + data_clean['Exporter'].astype(str)
hdfe = pf.feols('ln_trade ~ RTA + EU | Exporter_Year + Importer_Year + Importer_Exporter',
                data=data_clean, vcov='iid')
print("\n###\nHDFE Gravity:")
hdfe.summary()

# C.5 Multiplicative Gravity (PPML with heteroskedastic-robust standard errors)
# PPML sample: trade enters in levels, so only rows with a missing value in
# Trade, the policy dummies or the identifiers are dropped
data_ppml = data.dropna(subset=['Trade', 'RTA', 'EU', 'Exporter', 'Importer', 'Year']).copy()
# Fixed-effect identifiers: country-year and directional pair strings
data_ppml['Exporter_Year'] = data_ppml['Exporter'].astype(str) + '_' + data_ppml['Year'].astype(str)
data_ppml['Importer_Year'] = data_ppml['Importer'].astype(str) + '_' + data_ppml['Year'].astype(str)
data_ppml['Importer_Exporter'] = data_ppml['Importer'].astype(str) + '_' + data_ppml['Exporter'].astype(str)

multiplicative = pf.fepois('Trade ~ RTA + EU | Exporter_Year + Importer_Year + Importer_Exporter',
                           data=data_ppml, vcov='hetero')
print("\n###\nMultiplicative Gravity (PPML):")
# summary() prints the table itself and returns None; wrapping it in print()
# would add a stray "None" line after the table
multiplicative.summary()

# ************************************************************************
# D. Generate scatter plot and export results
# ************************************************************************

# Attach the PPML fitted values to the estimation data
# NOTE(review): `_Y_hat_response` is an underscore-prefixed attribute of the
# fitted model, and this assignment needs exactly one value per row of
# data_ppml - confirm the estimator did not drop any observations
data_ppml['predicted_trade'] = multiplicative._Y_hat_response

# Series shown on the horizontal and vertical axes
predicted_trade = data_ppml['predicted_trade']
actual_trade = data_ppml['Trade']

# Scatter of actual trade against predicted trade
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(predicted_trade, actual_trade,
           alpha=0.6, s=20, color='blue', edgecolors='none')

# Overlay the least-squares line of actual on predicted trade, drawn between
# the smallest and largest predicted values
slope, intercept, r_value, p_value, std_err = stats.linregress(predicted_trade, actual_trade)
line_x = np.array([predicted_trade.min(), predicted_trade.max()])
line_y = intercept + slope * line_x
ax.plot(line_x, line_y, color='red', linewidth=2, label=f'Fitted line (R² = {r_value**2:.3f})')

# Axis labels, title, legend and a light grid
ax.set_xlabel('Trade, Predicted from Gravity', fontsize=12)
ax.set_ylabel('Actual Trade', fontsize=12)
ax.set_title('Scatterplot of Trade vs Predicted', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()


# Collect the five estimated specifications, in the order they were fitted,
# and export them as one regression table. Each cell is formatted as the
# coefficient with its standard error in parentheses on the next line.
estimated_models = [naive, structural, panel_hdfe, hdfe, multiplicative]
pf.etable(
    estimated_models,
    coef_fmt='b \n (se)',
    signif_code=[0.01, 0.05, 0.10],
    type='gt',
    file_name='gravity_estimates.html',
)
print("\nRegression results have been exported to 'gravity_estimates.html'.")


