# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/40_make.ipynb.

# %% auto 0
__all__ = ['make_rings', 'rings', 'make_circles', 'circles', 'make_orbits', 'orbits', 'make_diamonds', 'diamonds', 'jacks']

# %% ../nbs/40_make.ipynb 6
#| export


# %% ../nbs/40_make.ipynb 8
#| export


# %% ../nbs/40_make.ipynb 11
#| export

# %% ../nbs/40_make.ipynb 13
try: import numpy as np
except ImportError: ...

try: import pandas as pd
except ImportError: ...

try: import matplotlib.pyplot as plt
except ImportError: ...

try: import seaborn as sns
except ImportError: ...

# %% ../nbs/40_make.ipynb 15
#| export


# %% ../nbs/40_make.ipynb 17
from quac import real, intq

# %% ../nbs/40_make.ipynb 19
from .atyp import XYArray, LabelArray
from .trig import rotate, ntheta, catrad, chord
from .data import npxy, dfxy, addx
from .cats import negcats, catdists
from .poly import orbit, circle, subring, diamond

# %% ../nbs/40_make.ipynb 22
def make_rings(
    n: int, 
    p: int = 400, 
    r: float = .8, 
    zscale: float = 0.02, 
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3
) -> tuple[XYArray, LabelArray]:
    '''Generates a dataset of rings.

    One ring is produced for each of the ``n`` locations returned by
    ``ntheta(n)``. Each ring is built with ``subring`` and appended to the
    dataset with ``addx``. The base radius ``r`` is first multiplied by half
    of ``chord`` evaluated on those locations (presumably so that neighbouring
    rings do not overlap -- TODO confirm against ``chord``).
    
    Parameters
    ----------
    n : int
        Number of rings to generate.
        
    p : int, default: 400
        Number of points per ring, by default 400.
    
    r : float, default: 0.8
        Base radius for the rings, by default .8.
    
    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.
        
    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.
        
    ncats : int or None, default: 5
        Number of categories, by default 5. If None, ``n`` is used.
        
    seed : int, default: 3
        Random seed for reproducibility, by default 3.
    
    Returns
    -------
    tuple[XYArray, LabelArray]
        The generated dataset of rings.
    '''
    data, cats = npxy(seed=seed)
    if ncats is None: ncats = n
    
    # The ring count follows ``n`` (as in make_circles, make_orbits and
    # make_diamonds), so it does not depend on ``ncats``. ``ncats`` only
    # controls the number of categories passed on to ``addx``.
    rlocs = ntheta(n)
    r *= (0.5 * chord(rlocs, sx=0.5, sy=0.5))
    
    for i, t in enumerate(rlocs):
        newx = subring(p, r=r, t=t)
        data, cats = addx(i, newx, data, cats, zscale, distcats, ncats, seed)
    return data, cats

def rings(
    n: int = 5,
    p: int = 400,
    r: float = .8,
    zscale: float = 0.02,
    distcats: bool = True,
    ncats: intq = 5,
    seed: int = 3,
    label: str = 'label',
    use_index: bool = True
) -> pd.DataFrame:
    '''Builds a labeled DataFrame of rings.

    Convenience wrapper around ``make_rings``. The generated points and labels
    are passed to ``dfxy`` together with ``label`` and ``use_index``.

    Parameters
    ----------
    n : int, default: 5
        Number of rings to generate.

    p : int, default: 400
        Number of points per ring, by default 400.

    r : float, default: 0.8
        Base radius for the rings, by default .8.

    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.

    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.

    ncats : int or None, default: 5
        Number of categories, by default 5.

    seed : int, default: 3
        Random seed for reproducibility, by default 3.

    label : str, default: 'label'
        Name of the label column in the output DataFrame, by default 'label'.

    use_index : bool, default: True
        Whether to use labels as DataFrame index, by default True.

    Returns
    -------
    pd.DataFrame
        The generated labeled dataset of rings.

    Examples
    --------
    >>> df_dist, df_cats = rings(zscale=0.02, ncats=5), rings(zscale=0.02, distcats=False)
    >>> df_dist.head()
    |   label |         x |           y |
    |--------:|----------:|------------:|
    |       0 | -0.267671 | -0.00960102 |
    |       0 | -0.273289 |  0.0159537  |
    |       0 | -0.257964 | -0.0112079  |
    |       0 | -0.285742 |  0.00354575 |
    |       0 | -0.260981 |  0.0183896  |

    >>> fig = plt.figure(figsize=(2 * 4, 4))
    >>> fig.add_subplot(1, 2, 1)
    >>> sns.scatterplot(data = df_dist, x='x', y='y', hue='label', palette='pastel')
    >>> fig.add_subplot(1, 2, 2)
    >>> sns.scatterplot(data = df_cats, x='x', y='y', hue='label', palette='pastel')
    '''
    points, labels = make_rings(
        n=n, p=p, r=r, zscale=zscale,
        distcats=distcats, ncats=ncats, seed=seed,
    )
    return dfxy(points, labels, label, use_index)

# %% ../nbs/40_make.ipynb 27
def make_circles(
    n: int, 
    p: int = 400, 
    zscale: float = 0.02, 
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3,
) -> tuple[XYArray, LabelArray]:
    '''Generates a dataset of circles.

    Builds ``n`` circles whose radii are evenly spaced in (0, 1]: the radii
    are 1/n, 2/n, ..., 1. Each circle is appended with ``addx``. When
    ``distcats`` is set, the labels are recomputed at the end with
    ``catdists`` over all points.

    Parameters
    ----------
    n : int
        Number of circles to generate.

    p : int, default: 400
        Number of points per circle, by default 400.

    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.

    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.

    ncats : int or None, default: 5
        Number of categories, by default 5. If None, ``n`` is used.

    seed : int, default: 3
        Random seed for reproducibility, by default 3.

    Returns
    -------
    tuple[XYArray, LabelArray]
        The generated dataset of circles.
    '''
    data, cats = npxy(seed=seed)
    ncats = n if ncats is None else ncats

    # n + 1 evenly spaced samples from 0 to 1; dropping the leading 0 leaves n radii.
    for idx, radius in enumerate(np.linspace(0.0, 1.0, num=n + 1)[1:]):
        data, cats = addx(idx, circle(p, radius), data, cats, zscale, distcats, ncats, seed)

    # Once every circle is in place, relabel the full point set with catdists.
    if distcats:
        cats = catdists(data, ncats)
    return data, cats

def circles(
    n: int = 5,
    p: int = 400,
    zscale: float = 0.02,
    distcats: bool = True,
    ncats: intq = 5,
    seed: int = 3,
    label: str = 'label',
    use_index: bool = True
) -> pd.DataFrame:
    '''Builds a labeled DataFrame of circles.

    Convenience wrapper around ``make_circles``. The generated points and
    labels are passed to ``dfxy`` together with ``label`` and ``use_index``.

    Parameters
    ----------
    n : int, default: 5
        Number of circles to generate.

    p : int, default: 400
        Number of points per circle, by default 400.

    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.

    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.

    ncats : int or None, default: 5
        Number of categories, by default 5.

    seed : int, default: 3
        Random seed for reproducibility, by default 3.

    label : str, default: 'label'
        Name of the label column in the output DataFrame, by default 'label'.

    use_index : bool, default: True
        Whether to use labels as DataFrame index, by default True.

    Returns
    -------
    pd.DataFrame
        The generated dataset of circles.

    Examples
    --------
    >>> df_dist, df_cats = circles(5, zscale=0.02, ncats=4), circles(5, zscale=0.02, ncats=4, distcats=False)
    >>> df_dist.head()
    |   label |        x |           y |
    |--------:|---------:|------------:|
    |       0 | 0.174754 |  0.0222349  |
    |       0 | 0.207543 | -0.00345989 |
    |       0 | 0.208733 |  0.0195313  |
    |       0 | 0.202914 |  0.0124247  |
    |       0 | 0.231898 |  0.0187599  |

    >>> fig = plt.figure(figsize=(2 * 4, 4))
    >>> fig.add_subplot(1, 2, 1)
    >>> sns.scatterplot(data = df_dist, x='x', y='y', hue='label', palette='pastel')
    >>> fig.add_subplot(1, 2, 2)
    >>> sns.scatterplot(data = df_cats, x='x', y='y', hue='label', palette='pastel')
    '''
    points, labels = make_circles(
        n=n, p=p, zscale=zscale,
        distcats=distcats, ncats=ncats, seed=seed,
    )
    return dfxy(points, labels, label, use_index)

# %% ../nbs/40_make.ipynb 32
def make_orbits(
    n: int, 
    p: int = 400, 
    w: real = .5, 
    zscale: float = .02, 
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3,
) -> tuple[XYArray, LabelArray]:
    '''Generates a dataset of orbits.

    For each index ``k`` in ``range(n)``, an orbit is built with
    ``orbit(p, w)``. It is rotated by ``catrad(k, n)`` and appended with
    ``addx`` under category index ``k``.

    Parameters
    ----------
    n : int
        Number of orbits to generate.

    p : int, default: 400
        Number of points per orbit, by default 400.

    w : real, default: .5
        Width of the orbits.

    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.

    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.

    ncats : int or None, default: 5
        Number of categories, by default 5. If None, ``n`` is used.

    seed : int, default: 3
        Random seed for reproducibility, by default 3.

    Returns
    -------
    tuple[XYArray, LabelArray]
        The generated dataset of orbits.
    '''
    data, cats = npxy(seed=seed)
    ncats = n if ncats is None else ncats

    for k in range(n):
        # Build the k-th orbit, then turn it by the k-th category angle.
        spun = rotate(orbit(p, w), catrad(k, n))
        data, cats = addx(k, spun, data, cats, zscale, distcats, ncats, seed)
    return data, cats

def orbits(
    n: int = 5, 
    p: int = 400, 
    w: real = 0.2, 
    zscale: float = 0.2, 
    distcats: bool = True,
    ncats: intq = 5, 
    seed: int = 3,
    label: str = 'label', 
    use_index: bool = True
) -> pd.DataFrame:
    '''Generates a labeled DataFrame of orbits.

    Wrapper around ``make_orbits``. The generated points and labels are
    passed to ``dfxy`` together with ``label`` and ``use_index``.

    Note that this wrapper's defaults differ from ``make_orbits``: here
    ``w`` defaults to 0.2 (vs .5) and ``zscale`` to 0.2 (vs .02).
    
    Parameters
    ----------
    n : int, default: 5
        Number of orbits to generate.
        
    p : int, default: 400
        Number of points per orbit, by default 400.
        
    w : real, default: .2
        Width of the orbits, by default .2.
    
    zscale : float, default: 0.2
        Scale of the Gaussian noise to add, by default 0.2.
        
    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.
        
    ncats : int or None, default: 5
        Number of categories, by default 5.
        
    seed : int, default: 3
        Random seed for reproducibility, by default 3.
        
    label : str, default: 'label'
        Name of the label column in the output DataFrame, by default 'label'.
        
    use_index : bool, default: True
        Whether to use labels as DataFrame index, by default True.
        
    Returns
    -------
    pd.DataFrame
        The generated dataset of orbits.
        
    Examples
    --------
    >>> df_dist, df_cats = orbits(zscale=0.02, ncats=5), orbits(zscale=0.02, distcats=False)
    >>> df_dist.head()
    |   label |        x |           y |
    |--------:|---------:|------------:|
    |       0 | 0.209824 |  0.00985485 |
    |       0 | 0.202101 |  0.047002   |
    |       0 | 0.194644 | -0.0017663  |
    |       0 | 0.188793 |  0.0697399  |
    |       0 | 0.197442 |  0.0831813  |
    
    >>> fig = plt.figure(figsize=(2 * 4, 4))
    >>> fig.add_subplot(1, 2, 1)
    >>> sns.scatterplot(data = df_dist, x='x', y='y', hue='label', palette='pastel')
    >>> fig.add_subplot(1, 2, 2)
    >>> sns.scatterplot(data = df_cats, x='x', y='y', hue='label', palette='pastel')
    '''
    points, labels = make_orbits(n, p, w, zscale, distcats, ncats, seed)
    return dfxy(points, labels, label, use_index)

# %% ../nbs/40_make.ipynb 37
def make_diamonds(
    n: int, 
    p: int = 400,
    w: real = .5, 
    h: real = 1, 
    zscale: float = .02,
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3,
) -> tuple[XYArray, LabelArray]:
    '''Generates a dataset of diamonds.

    For each index ``i`` in ``range(n)``, a diamond is built with
    ``diamond(p, w, h)``. It is rotated by ``catrad(i, n)`` and appended with
    ``addx`` under category index ``i``.
    
    Parameters
    ----------
    n : int
        Number of diamonds to generate.
        
    p : int, default: 400
        Number of points per diamond, by default 400.
        
    w : real, default: .5
        Width of the diamonds, by default .5.
        
    h : real, default: 1
        Height of the diamonds, by default 1.
    
    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.
        
    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.
        
    ncats : int or None, default: 5
        Number of categories, by default 5. If None, ``n`` is used.
        
    seed : int, default: 3
        Random seed for reproducibility, by default 3.
        
    Returns
    -------
    tuple[XYArray, LabelArray]
        The generated dataset of diamonds.
    '''
    data, cats = npxy(seed=seed)
    if ncats is None: ncats = n

    for i in range(n):
        newx = diamond(p, w, h)
        # Spread the copies around the origin by the i-th category angle.
        newx = rotate(newx, catrad(i, n))
        data, cats = addx(i, newx, data, cats, zscale, distcats, ncats, seed)
    return data, cats

def diamonds(
    n: int = 5, 
    p: int = 400, 
    w: real = .5, 
    h: real = 1, 
    zscale: float = .02, 
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3,
    label: str = 'label', 
    use_index: bool = True
) -> pd.DataFrame:
    '''Generates a labeled DataFrame of diamonds.

    Wrapper around ``make_diamonds``. The generated points and labels are
    passed to ``dfxy`` together with ``label`` and ``use_index``.
    
    Parameters
    ----------
    n : int, default: 5
        Number of diamonds to generate.
        
    p : int, default: 400
        Number of points per diamond, by default 400.
        
    w : real, default: .5
        Width of the diamonds, by default .5.
        
    h : real, default: 1
        Height of the diamonds, by default 1.
    
    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.
        
    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.
        
    ncats : int or None, default: 5
        Number of categories, by default 5.
        
    seed : int, default: 3
        Random seed for reproducibility, by default 3.
        
    label : str, default: 'label'
        Name of the label column in the output DataFrame, by default 'label'.
        
    use_index : bool, default: True
        Whether to use labels as DataFrame index, by default True.
        
    Returns
    -------
    pd.DataFrame
        The generated dataset of diamonds.
        
    Examples
    --------
    >>> df_dist, df_cats = diamonds(zscale=0.02, ncats=5), diamonds(zscale=0.02, distcats=False)
    >>> df_dist.head()
    |   label |         x |        y |
    |--------:|----------:|---------:|
    |       2 | -0.245243 | 0.486684 |
    |       2 | -0.270382 | 0.488943 |
    |       2 | -0.206727 | 0.48754  |
    |       2 | -0.24259  | 0.527085 |
    |       2 | -0.24874  | 0.557649 |

    >>> fig = plt.figure(figsize=(2 * 4, 4))
    >>> fig.add_subplot(1, 2, 1)
    >>> sns.scatterplot(data = df_dist, x='x', y='y', hue='label', palette='pastel')
    >>> fig.add_subplot(1, 2, 2)
    >>> sns.scatterplot(data = df_cats, x='x', y='y', hue='label', palette='pastel')
    '''
    points, labels = make_diamonds(n, p, w, h, zscale, distcats, ncats, seed)
    return dfxy(points, labels, label, use_index)

# %% ../nbs/40_make.ipynb 42
def jacks(
    n: int = 5, 
    p: int = 400, 
    zscale: float = 0.02, 
    distcats: bool = True, 
    ncats: intq = 5, 
    seed: int = 3,
    label: str = 'label', 
    use_index: bool = True, 
    use_neg: bool = False, 
) -> pd.DataFrame:
    '''Generates a dataset of jacks.

    Builds zero-width orbits via ``orbits(..., w=0, ...)``. If ``use_neg`` is
    set, the frame is passed through ``negcats``. Extra Gaussian jitter of
    scale ``zscale`` is then added to the ``x`` and ``y`` columns, on top of
    whatever noise ``orbits`` itself applies.

    Parameters
    ----------
    n : int, default: 5
        Number of jacks (zero-width orbits) to generate.
        
    p : int, default: 400
        Number of points per jack, by default 400.
        
    zscale : float, default: 0.02
        Scale of the Gaussian noise to add, by default 0.02.
        
    distcats : bool, default: True
        Whether to categorize data by distance from the origin, by default True.
        
    ncats : int or None, default: 5
        Number of categories, by default 5.
        
    seed : int, default: 3
        Random seed for reproducibility, by default 3.

    label : str, default: 'label'
        Name of the label column in the output DataFrame, by default 'label'.

    use_index : bool, default: True
        Whether to use labels as DataFrame index, by default True.

    use_neg : bool, default: False
        Whether to apply ``negcats`` to the labels before jittering, by
        default False.
        
    Returns
    -------
    pd.DataFrame
        The generated dataset of jacks.
        
    Examples
    --------
    >>> df_dist, df_cats = jacks(5, 100, zscale=0.02, ncats=5), jacks(5, 100, zscale=0.02, distcats=False)
    >>> df_dist.head()
    |   label |           x |         y |
    |--------:|------------:|----------:|
    |       0 | -0.029137   | 0.0120227 |
    |       0 |  0.0189926  | 0.0382629 |
    |       0 |  0.0447255  | 0.132869  |
    |       0 | -0.00449615 | 0.190482  |
    |       0 |  0.0157224  | 0.213501  |

    >>> fig = plt.figure(figsize=(2 * 4, 4))
    >>> fig.add_subplot(1, 2, 1)
    >>> sns.scatterplot(data = df_dist, x='x', y='y', hue='label', palette='pastel')
    >>> fig.add_subplot(1, 2, 2)
    >>> sns.scatterplot(data = df_cats, x='x', y='y', hue='label', palette='pastel')
    '''
    # Keep the label as a regular column (use_index=False) while the frame is
    # post-processed; it is moved to the index at the end if requested.
    df = orbits(n, p, 0, zscale, distcats, ncats, seed, label, False)
    if use_neg: df = negcats(df, label)

    # NOTE(review): this jitter is drawn from NumPy's global RNG, not from
    # ``seed``. It is reproducible only if the global RNG was seeded earlier
    # (e.g. inside npxy) -- confirm.
    npts = len(df)
    df['x'] = df['x'] + np.random.normal(0, zscale, npts)
    df['y'] = df['y'] + np.random.normal(0, zscale, npts)
    if use_index: df = df.set_index(label)
    return df

# %% ../nbs/40_make.ipynb 47
#| export

