Metadata-Version: 2.1
Name: spark-tensorflow-distributor
Version: 0.1.0
Summary: This package helps users do distributed training with TensorFlow on their Spark clusters.
Home-page: https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-distributor
Author: sarthfrey
Author-email: sarth.frey@gmail.com
License: UNKNOWN
Description: # Spark TensorFlow Distributor
        
        This package helps users do distributed training with TensorFlow on their Spark clusters.
        
        ## Installation
        
        This package requires Python 3.6+, `tensorflow>=2.1.0` and `pyspark>=3.0.0` to run.
        To install `spark-tensorflow-distributor`, run:
        
        ```bash
        pip install spark-tensorflow-distributor
        ```
        
        The installation does not install PySpark because for most users, PySpark is already installed.
        If you do not have PySpark installed, you can install it directly:
        
        ```bash
        pip install 'pyspark>=3.0.0'
        ```
        
        Note that to use many features of this package, you must set up Spark custom
        resource scheduling for GPUs on your cluster. See the Spark docs for details.
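        As a sketch, a standalone cluster can advertise GPUs with configuration along
        these lines. The property names are Spark's standard GPU scheduling keys; the
        amounts and the discovery script path are placeholders to adapt to your cluster.
        
        ```
        # spark-defaults.conf (illustrative values; adapt to your cluster)
        spark.executor.resource.gpu.amount          1
        spark.task.resource.gpu.amount              1
        spark.worker.resource.gpu.discoveryScript   /path/to/getGpusResources.sh
        ```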
        
        ## Running Tests
        
        For integration tests, first build the master and worker images, then run the test script:
        
        ```bash
        docker-compose build --build-arg PYTHON_INSTALL_VERSION=3.7
        ./tests/integration/run.sh
        ```
        
        For linting, run the following.
        
        ```bash
        ./tests/lint.sh
        ```
        
        To use the autoformatter, run the following.
        
        ```bash
        yapf --recursive --in-place spark_tensorflow_distributor
        ```
        
        ## Examples
        
        Run the following example code in a `pyspark` shell:
        
        ```python
        from spark_tensorflow_distributor import MirroredStrategyRunner
        
        
        # Taken from https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
        def train():
            import tensorflow_datasets as tfds
            import tensorflow as tf
            BUFFER_SIZE = 10000
            BATCH_SIZE = 64
        
            def make_datasets_unbatched():
                # Scaling MNIST data from (0, 255] to (0., 1.]
                def scale(image, label):
                    image = tf.cast(image, tf.float32)
                    image /= 255
                    return image, label
                datasets, info = tfds.load(
                    name='mnist',
                    with_info=True,
                    as_supervised=True,
                )
                return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)
        
            def build_and_compile_cnn_model():
                model = tf.keras.Sequential([
                    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
                    tf.keras.layers.MaxPooling2D(),
                    tf.keras.layers.Flatten(),
                    tf.keras.layers.Dense(64, activation='relu'),
                    tf.keras.layers.Dense(10, activation='softmax'),
                ])
                model.compile(
                    loss=tf.keras.losses.sparse_categorical_crossentropy,
                    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
                    metrics=['accuracy'],
                )
                return model
        
            GLOBAL_BATCH_SIZE = 64 * 8
            train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE).repeat()
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
            train_datasets = train_datasets.with_options(options)
            multi_worker_model = build_and_compile_cnn_model()
            multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)
            return tf.config.experimental.list_physical_devices('GPU')
        
        MirroredStrategyRunner(num_slots=8).run(train)
        ```
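        Two details of the example are worth spelling out: `AutoShardPolicy.DATA` makes
        each worker read the full dataset and keep only every N-th element, and
        `MirroredStrategy` splits each global batch evenly across replicas ("slots").
        A minimal plain-Python sketch of both behaviors (the helper below is
        illustrative only, not part of this package's or tf.data's API):
        
        ```python
        # Illustrative sketch; `data_shard` is a hypothetical helper, not an API call.
        
        def data_shard(elements, num_workers, worker_index):
            # AutoShardPolicy.DATA: worker i keeps every num_workers-th element,
            # starting at offset i.
            return [x for i, x in enumerate(elements)
                    if i % num_workers == worker_index]
        
        # With 2 workers, worker 0 sees the even-indexed elements.
        print(data_shard(list(range(10)), num_workers=2, worker_index=0))
        
        # MirroredStrategy divides each global batch across replicas, so with
        # 8 slots a global batch of 64 * 8 trains 64 examples per replica.
        NUM_SLOTS = 8
        GLOBAL_BATCH_SIZE = 64 * NUM_SLOTS
        print(GLOBAL_BATCH_SIZE // NUM_SLOTS)
        ```
        
        This is why the example sizes its global batch as a multiple of the slot
        count: each replica then trains on a constant per-replica batch size.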
        
        
Platform: UNKNOWN
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Developers
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Natural Language :: English
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: Software Development :: Version Control :: Git
Requires-Python: >=3.6
Description-Content-Type: text/markdown
