"""
Pull plot
=========

Compare data and model with pulls.
"""

from plothist_utils import get_dummy_data

# Load the example dataset used throughout this script.
# NOTE(review): presumably a pandas DataFrame -- it is indexed by column name
# and by boolean masks further down; confirm against plothist_utils.
df = get_dummy_data()

from plothist import get_color_palette, make_hist

# Define the variable to plot and the histogram binning

key = "variable_1"
# Named ``x_range`` (not ``range``) so the builtin ``range`` is not shadowed.
x_range = (-9, 12)
# Single source of truth for the bin count shared by data and backgrounds.
n_bins = 50
category = "category"

# Define masks
# NOTE: signal_mask is not used in the visible part of this example.
signal_mask = df[category] == 7
data_mask = df[category] == 8

background_categories = [0, 1, 2]
background_categories_labels = [f"c{i}" for i in background_categories]
background_categories_colors = get_color_palette(
    "cubehelix", len(background_categories)
)

# One boolean mask per background category, in the same order as the labels
background_masks = [df[category] == p for p in background_categories]

# Make histograms (data and every background use the same binning)
data_hist = make_hist(df[key][data_mask], bins=n_bins, range=x_range, weights=1)
background_hists = [
    make_hist(df[key][mask], bins=n_bins, range=x_range, weights=1)
    for mask in background_masks
]

# Optional: scale to data
# The factor is the ratio of the data total to the summed-background total,
# applied uniformly to every background component.
background_scaling_factor = data_hist.sum().value / sum(background_hists).sum().value
background_hists = [background_scaling_factor * h for h in background_hists]

###
from plothist import plot_data_model_comparison

# Axis titles; the raw f-string keeps the LaTeX-style backslashes intact.
x_axis_label = rf"${key}\,\,[TeV/c^2]$"
y_axis_label = "Candidates per 0.42 $TeV/c^2$"

# Compare the data histogram with the stacked background components,
# requesting the "pull" comparison.
fig, ax_main, ax_comparison = plot_data_model_comparison(
    data_hist=data_hist,
    stacked_components=background_hists,
    stacked_labels=background_categories_labels,
    stacked_colors=background_categories_colors,
    xlabel=x_axis_label,
    ylabel=y_axis_label,
    comparison="pull",
)

# Write the figure to an SVG file with a tight bounding box.
fig.savefig("model_examples_pull.svg", bbox_inches="tight")
