Quick Start

Using Maggy for distributed training works as follows:

  • Optionally, define a model generator object, similar to what is done for Ablation Studies.
    class MyModel(tf.keras.Model):
        def __init__(self, ...):
            super().__init__()
            # Define the layers here
        def call(self, ...):
            # Forward pass
  • Optionally, define your train and test datasets. Maggy will shard them.
    # Extract the data
    (x_train, y_train),(x_test, y_test) = split_dataset(dataset)
    # Do some preprocessing operations
  • Define a training function containing the training logic.

    def training_function(model, train_set, test_set, hparams):
        # Training and testing logic

  • Create the configuration object and run the experiment.

    config = TfDistributedConfig(name="tf_test",
                                 train_set=(x_train, y_train),
                                 test_set=(x_test, y_test))
    experiment.lagom(train_fn=training_function, config=config)

    The configuration object accepts the following parameters:

    • model: A tf.keras.Model subclass, or a list of them. This must be the class itself, not an instance.
    • train_set: The training set for the training function. If you load the set inside the training function, this can be omitted.
    • test_set: The test set for the training function. If you load the set inside the training function, this can be omitted.
    • process_data: The function for processing the data.
    • hparams: Model parameters that should be used during model initialization. Primarily used to give an interface for hyperparameter optimization.
    • name: Experiment name.
    • hb_interval: Heartbeat interval at which the server polls.
    • description: A description of the experiment.
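
Two points above, that the datasets are sharded across workers and that model must be the class rather than an instance, can be illustrated with a small framework-free sketch. It is not Maggy's implementation: TinyModel, shard, and run_worker are hypothetical names, and round-robin sharding is an assumed strategy chosen only for simplicity.

```python
class TinyModel:
    """Hypothetical stand-in for a tf.keras.Model subclass."""

    def __init__(self, units=8):
        self.units = units


def shard(samples, num_workers, worker_index):
    """Give a worker every num_workers-th sample (assumed round-robin sharding)."""
    return samples[worker_index::num_workers]


def run_worker(model_cls, train_set, hparams, num_workers, worker_index):
    """Simulate one worker: build a fresh model and take that worker's shard."""
    if not isinstance(model_cls, type):
        raise TypeError("pass the model class itself, not an instance")
    # Passing the class lets every worker construct its own instance from hparams.
    model = model_cls(**hparams)
    return model, shard(train_set, num_workers, worker_index)


train_set = list(range(10))
workers = [run_worker(TinyModel, train_set, {"units": 32}, 2, i) for i in range(2)]
for index, (model, data) in enumerate(workers):
    print(index, model.units, data)
```

Each simulated worker builds its own model from the class and hparams and receives a disjoint slice of the data. Passing an instance such as TinyModel() instead of the class raises a TypeError in this sketch.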