# Notebook shell command (IPython `!` magic): download the UCI "Wholesale
# customers" CSV and save it in the current working directory under a
# space-free file name (-O), so later code can refer to it without URL escaping.
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00292/Wholesale%20customers%20data.csv -O Wholesale_customers_data.csv
import os

import matplotlib.pyplot as plt
from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, GaussianMixtureModel, KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.sql import SparkSession
# Start (or reuse) the Spark session used by the rest of the script.
spark = SparkSession.builder.appName("WholesaleClustering").getOrCreate()

# The download step saves the CSV into the current working directory, so read
# that same relative name rather than hard-coding Colab's /content folder.
# abspath resolves it on the Python side before the path is handed to Spark
# (in Colab this is still /content/Wholesale_customers_data.csv).
url = os.path.abspath("Wholesale_customers_data.csv")
df = spark.read.csv(url, header=True, inferSchema=True)

# Columns combined into the feature vector; their order defines the vector
# positions that later code indexes into.
feature_cols = ["Fresh", "Milk", "Grocery", "Frozen", "Detergents_Paper", "Delicassen"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_raw")
data = assembler.transform(df)

# Standardize every feature to zero mean and unit variance so no single column
# dominates the distance-based clustering.
scaler = StandardScaler(inputCol="features_raw", outputCol="features", withMean=True, withStd=True)
# `dataset` is reused by several model fits, the evaluator and the plots; cache
# it so the CSV is not re-read and re-scaled for every Spark action.
dataset = scaler.fit(data).transform(data).cache()
# Number of clusters and random seed shared by every algorithm, so the
# comparison below is like-for-like and both can be tuned in one place.
N_CLUSTERS = 5
SEED = 42

# Input/output column names shared by all estimators and the evaluator.
column_params = {"featuresCol": "features", "predictionCol": "prediction"}

# Display name -> unfitted estimator to compare.
models = {
    "KMeans": KMeans(k=N_CLUSTERS, seed=SEED, **column_params),
    "Bisecting KMeans": BisectingKMeans(k=N_CLUSTERS, seed=SEED, **column_params),
    "Gaussian Mixture": GaussianMixture(k=N_CLUSTERS, seed=SEED, **column_params),
}

# Silhouette score computed on the same scaled "features" column the models use.
evaluator = ClusteringEvaluator(metricName="silhouette", **column_params)
# Fit every algorithm on the scaled data, report its silhouette score and its
# cluster centres/means, and scatter-plot the assignments on two scaled features.
# The Spark session is stopped in `finally`, even if a fit or a plot fails.
plot_x_col, plot_y_col = "Milk", "Grocery"
# Derive vector positions from feature_cols instead of hard-coding 1 and 2.
plot_x_idx = feature_cols.index(plot_x_col)
plot_y_idx = feature_cols.index(plot_y_col)

try:
    for name, model in models.items():
        print(f"\n===== {name} =====")
        fitted = model.fit(dataset)
        predictions = fitted.transform(dataset)

        # Silhouette needs at least two distinct clusters, so check first.
        num_clusters = predictions.select("prediction").distinct().count()
        if num_clusters > 1:
            silhouette = evaluator.evaluate(predictions)
            print(f"{name} Silhouette Score: {silhouette}")
        else:
            print(f"{name} collapsed to 1 cluster; silhouette not available.")

        # A Gaussian mixture exposes per-component means via gaussiansDF, while
        # the k-means models expose clusterCenters(). Dispatch on the model
        # type, not the display name, so renaming a dict key cannot route a
        # mixture model into clusterCenters().
        if isinstance(fitted, GaussianMixtureModel):
            print(f"{name} Cluster Means:")
            for row in fitted.gaussiansDF.select("mean").collect():
                print(row["mean"])
        else:
            print(f"{name} Cluster Centers:")
            for center in fitted.clusterCenters():
                print(center)

        plot_df = predictions.select("features", "prediction")
        plot_df.show(5, truncate=False)

        # Collect only the two columns the plot needs to the driver.
        pdf = plot_df.toPandas()
        plt.figure(figsize=(7, 5))
        points = plt.scatter(
            pdf["features"].apply(lambda vec: vec[plot_x_idx]),
            pdf["features"].apply(lambda vec: vec[plot_y_idx]),
            c=pdf["prediction"],
            cmap="rainbow",
        )
        plt.colorbar(points, label="Cluster")
        plt.xlabel(f"{plot_x_col} (scaled)")
        plt.ylabel(f"{plot_y_col} (scaled)")
        plt.title(f"{name} Clustering (Wholesale Customers)")
        plt.show()
finally:
    spark.stop()
