.. note::
    :class: sphx-glr-download-link-note

    Click :ref:`here <sphx_glr_download_auto_examples_example_logistic_group_lasso.py>` to download the full example code
.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_example_logistic_group_lasso.py:


GroupLasso for logistic regression
==================================

A sample script for group lasso regression

Setup
-----


.. code-block:: default


    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.linear_model import Ridge
    from sklearn.pipeline import Pipeline

    from group_lasso import LogisticGroupLasso

    np.random.seed(0)
    LogisticGroupLasso.LOG_LOSSES = True









Set dataset parameters
----------------------


.. code-block:: default

    group_sizes = [np.random.randint(10, 20) for i in range(50)]
    active_groups = [np.random.randint(2) for _ in group_sizes]
    groups = np.concatenate(
        [size * [i] for i, size in enumerate(group_sizes)]
    )
    num_coeffs = sum(group_sizes)
    num_datapoints = 10000
    noise_std = 1









Generate data matrix
--------------------


.. code-block:: default

    X = np.random.standard_normal((num_datapoints, num_coeffs))









Generate coefficients
---------------------


.. code-block:: default

    w = np.concatenate(
        [
            np.random.standard_normal(group_size) * is_active
            for group_size, is_active in zip(group_sizes, active_groups)
        ]
    )
    w = w.reshape(-1, 1)
    true_coefficient_mask = w != 0
    intercept = 2









Generate regression targets
---------------------------


.. code-block:: default

    y_true = X @ w + intercept
    y = y_true + np.random.randn(*y_true.shape) * noise_std
    p = 1 / (1 + np.exp(-y))
    p_true = 1 / (1 + np.exp(-y_true))
    c = np.random.binomial(1, p_true)









View noisy data and compute maximum accuracy
--------------------------------------------


.. code-block:: default

    plt.figure()
    plt.plot(p, p_true, ".")
    plt.xlabel("Noisy probabilities")
    plt.ylabel("Noise-free probabilities")
    # Use noisy y as true because that is what we would have access
    # to in a real-life setting.
    best_accuracy = ((p_true > 0.5) == c).mean()





.. image:: /auto_examples/images/sphx_glr_example_logistic_group_lasso_001.png
    :class: sphx-glr-single-img





Generate estimator and train it
-------------------------------


.. code-block:: default

    gl = LogisticGroupLasso(
        groups=groups,
        group_reg=0.05,
        l1_reg=0,
        scale_reg="inverse_group_size",
        subsampling_scheme=1,
        supress_warning=True,
    )

    gl.fit(X, c)






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /home/yngvem/Programming/morro/group-lasso/src/group_lasso/_group_lasso.py:749: UserWarning: Subsampling is not stable for logistic regression group lasso.
      "Subsampling is not stable for logistic regression group lasso."




Extract results and compute performance metrics
-----------------------------------------------


.. code-block:: default


    # Extract info from estimator
    pred_c = gl.predict(X)
    sparsity_mask = gl.sparsity_mask_
    w_hat = gl.coef_

    # Compute performance metrics
    accuracy = (pred_c == c).mean()

    # Print results
    print(f"Number variables: {len(sparsity_mask)}")
    print(f"Number of chosen variables: {sparsity_mask.sum()}")
    print(f"Accuracy: {accuracy}, best possible accuracy = {best_accuracy}")






.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Number variables: 720
    Number of chosen variables: 334
    Accuracy: 0.9435, best possible accuracy = 0.9698




Visualise regression coefficients
---------------------------------


.. code-block:: default

    coef = gl.coef_[:, 1] - gl.coef_[:, 0]
    plt.figure()
    plt.plot(w / np.linalg.norm(w), ".", label="True weights")
    plt.plot(
        coef / np.linalg.norm(coef),
        ".",
        label="Estimated weights",
    )

    plt.figure()
    plt.plot([w.min(), w.max()], [coef.min(), coef.max()], "gray")
    plt.scatter(w, coef, s=10)
    plt.ylabel("Learned coefficients")
    plt.xlabel("True coefficients")

    plt.figure()
    plt.plot(gl.losses_)

    plt.show()



.. rst-class:: sphx-glr-horizontal


    *

      .. image:: /auto_examples/images/sphx_glr_example_logistic_group_lasso_002.png
            :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_example_logistic_group_lasso_003.png
            :class: sphx-glr-multi-img

    *

      .. image:: /auto_examples/images/sphx_glr_example_logistic_group_lasso_004.png
            :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    /home/yngvem/Programming/morro/group-lasso/examples/example_logistic_group_lasso.py:132: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
      plt.show()





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  2.277 seconds)


.. _sphx_glr_download_auto_examples_example_logistic_group_lasso.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download

     :download:`Download Python source code: example_logistic_group_lasso.py <example_logistic_group_lasso.py>`



  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: example_logistic_group_lasso.ipynb <example_logistic_group_lasso.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
