TFEstimatorBase

Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow

class spiegel.estimator.tf_estimator_base.TFEstimatorBase(inputShape, numOutputs, checkpointPath='', weightsPath='', loggers=[])

Bases: spiegel.estimator.estimator_base.EstimatorBase

Parameters
  • inputShape (tuple) – Shape of the matrix that will be passed to the model input

  • numOutputs (int) – Number of outputs the model has

  • checkpointPath (string, optional) – If given, checkpoints will be saved to this location during training, defaults to “”

  • weightsPath (string, optional) – If given, model weights will be loaded from this file, defaults to “”

addTestingData(input, output, batchSize=64)

Create a tf.data Dataset from input and output for model testing, batching the data if desired

Parameters
  • input (np.array) – matrix of data to use as testing data

  • output (np.array) – matrix of data to use as ground truth for testing data

  • batchSize (int, optional) – If provided, will batch data into batches of this size; set to None or 0 to prevent batching. Defaults to 64
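The batching behaviour described for addTestingData can be approximated with a plain tf.data pipeline. This is a minimal sketch assuming the method pairs input and output with tf.data.Dataset.from_tensor_slices and then calls batch(); make_testing_dataset is a hypothetical helper, not part of spiegel. It shows that None or 0 leaves examples unbatched and that, with TensorFlow's default drop_remainder=False, the last batch may be smaller than batchSize.

```python
import numpy as np
import tensorflow as tf


def make_testing_dataset(inputs, outputs, batch_size=64):
    """Hypothetical helper approximating addTestingData (not spiegel's code)."""
    # Pair each input row with its ground-truth row.
    dataset = tf.data.Dataset.from_tensor_slices((inputs, outputs))
    # None or 0 disables batching, so each element stays a single (x, y) pair.
    if batch_size:
        # drop_remainder defaults to False, so the final batch may be smaller.
        dataset = dataset.batch(batch_size)
    return dataset


x = np.random.rand(130, 8).astype("float32")  # 130 examples, 8 features
y = np.random.rand(130, 3).astype("float32")  # 3 target parameters each

batched = make_testing_dataset(x, y, batch_size=64)
print([int(xb.shape[0]) for xb, _ in batched])  # prints [64, 64, 2]
```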

addTrainingData(input, output, batchSize=64, shuffleSize=None)

Create a tf.data Dataset from input and output, then shuffle and batch the data for training

Parameters
  • input (np.array) – matrix of training data

  • output (np.array) – matrix of training data ground truth

  • batchSize (int, optional) – If provided, will batch data into batches of this size; set to None or 0 to prevent batching. Defaults to 64

  • shuffleSize (int, optional) – If provided, data will be shuffled using a buffer of shuffleSize elements. Defaults to None, in which case no shuffling occurs
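A training pipeline with the options addTrainingData describes might look like the sketch below. It assumes, without confirmation from this page, that shuffling is applied before batching; make_training_dataset is a hypothetical helper. Shuffling individual examples first means each epoch draws batches with a different mix of examples, and the (input, ground truth) pairing is preserved.

```python
import numpy as np
import tensorflow as tf


def make_training_dataset(inputs, outputs, batch_size=64, shuffle_size=None):
    """Hypothetical helper approximating addTrainingData (not spiegel's code)."""
    dataset = tf.data.Dataset.from_tensor_slices((inputs, outputs))
    if shuffle_size:
        # Shuffle individual examples before batching them.
        dataset = dataset.shuffle(shuffle_size)
    if batch_size:
        dataset = dataset.batch(batch_size)
    return dataset


x = np.arange(100, dtype="float32").reshape(100, 1)
y = x * 2.0  # ground truth is simply twice the input

train = make_training_dataset(x, y, batch_size=32, shuffle_size=100)
seen = np.concatenate([xb.numpy() for xb, _ in train]).ravel()
# Every example appears exactly once per epoch, just in a shuffled order.
print(len(seen), np.array_equal(np.sort(seen), x.ravel()))  # prints 100 True
```

TensorFlow's documentation notes that a buffer at least as large as the dataset is needed for a perfectly uniform shuffle; smaller values of shuffleSize trade randomness for memory.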

abstract buildModel()

Abstract method that subclasses must implement with the model definition
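This page does not say how buildModel should hand its model back to the base class (for example, by assigning an attribute or returning it), so the sketch below shows only the kind of Keras definition an implementation might build. The layer sizes, the sigmoid output (assuming parameters normalized to [0, 1]) and the helper name build_mlp are all illustrative assumptions; in practice the shapes would come from the inputShape and numOutputs constructor arguments.

```python
import tensorflow as tf


def build_mlp(input_shape, num_outputs):
    """Hypothetical model a buildModel implementation might construct."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        # Sigmoid keeps each predicted parameter in [0, 1] (assumes normalized targets).
        tf.keras.layers.Dense(num_outputs, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# For example, a 13 x 20 feature matrix mapped to 8 synthesizer parameters.
model = build_mlp(input_shape=(13, 20), num_outputs=8)
out = model(tf.zeros((2, 13, 20)))
print(out.shape)
```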

fit(epochs=1, callbacks=[], **kwargs)

Train the model for a fixed number of epochs on the training data, and on the validation data if it has been added to this estimator

Parameters
  • epochs (int, optional) – Number of epochs to train model on, defaults to 1

  • callbacks (list, optional) – List of Keras callbacks to use during training, defaults to []

  • kwargs – Keyword arguments passed to the model's fit method. See TensorFlow docs.
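Because epochs, callbacks and the extra keyword arguments are forwarded to the underlying Keras fit method, a call such as estimator.fit(epochs=3, callbacks=callbacks, verbose=0) on a hypothetical estimator instance should behave roughly like the plain Keras call below. The small model and random datasets stand in for the estimator's internals and are assumptions made only so the sketch runs on its own.

```python
import numpy as np
import tensorflow as tf

# Stand-ins for the estimator's model and its training/validation datasets.
x = np.random.rand(256, 16).astype("float32")
y = np.random.rand(256, 4).astype("float32")
train_ds = tf.data.Dataset.from_tensor_slices((x[:200], y[:200])).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((x[200:], y[200:])).batch(32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(4),
])
model.compile(optimizer="adam", loss="mse")

# Callbacks like this one would be passed through the `callbacks` argument.
callbacks = [tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)]

# `verbose` plays the role of an extra keyword argument forwarded via **kwargs.
history = model.fit(train_ds, validation_data=val_ds, epochs=3,
                    callbacks=callbacks, verbose=0)
print(sorted(history.history))
```

Monitoring val_loss only works when validation data has been supplied, which matches the docstring's condition that validation data is used only if it has been added to the estimator.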

loadModelFromCheckpoint()

Load model weights from a saved checkpoint

loadWeights(filepath, **kwargs)

Load model weights from an HDF5 or TensorFlow-format file

Parameters
  • filepath (string) – filepath to saved model weights

  • kwargs – optional keyword arguments passed to the Keras load_weights method, see TensorFlow docs.

predict(input)

Run prediction on input

Parameters
  • input (np.array) – matrix of input data to run predictions on. Can be a single example or a batch of examples.
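The docs state that predict accepts either a single example or a batch, but not how the two are told apart. One common approach is to compare the input's rank with the rank of inputShape and add a leading batch axis when they match. The helper below is a hypothetical sketch of that idea; whether predict does anything similar internally is an assumption. A caller can also use it to make sure every input has an explicit batch axis before calling predict.

```python
import numpy as np


def ensure_batch(x, input_shape):
    """Hypothetical helper: add a leading batch axis to a single example.

    An array whose rank equals len(input_shape) is treated as one example;
    an array with exactly one extra leading axis is treated as a batch.
    """
    x = np.asarray(x)
    if x.ndim == len(input_shape):
        return np.expand_dims(x, axis=0)
    if x.ndim == len(input_shape) + 1:
        return x
    raise ValueError(
        f"expected rank {len(input_shape)} or {len(input_shape) + 1}, got {x.ndim}"
    )


single = np.zeros((13, 20))
batch = np.zeros((5, 13, 20))
print(ensure_batch(single, (13, 20)).shape)  # prints (1, 13, 20)
print(ensure_batch(batch, (13, 20)).shape)   # prints (5, 13, 20)
```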

static rootMeanSquaredError(labels, prediction)

Static method for calculating the root mean squared error (RMSE) between predictions and ground-truth labels

Parameters
  • labels (Tensor) – Matrix of ground truth labels

  • prediction (Tensor) – Matrix of predictions
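Assuming the mean is taken over every element of the label matrix, the quantity is RMSE = sqrt(mean((labels - prediction)^2)). In TensorFlow this is typically written as tf.sqrt(tf.reduce_mean(tf.square(labels - prediction))). Whether rootMeanSquaredError averages over all elements or per example is not stated here. The NumPy reference below, with the hypothetical name rmse, is useful for checking results by hand.

```python
import numpy as np


def rmse(labels, prediction):
    """Root mean squared error over all elements (NumPy reference sketch)."""
    labels = np.asarray(labels, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    # Square the element-wise errors, average them, then take the square root.
    return float(np.sqrt(np.mean((labels - prediction) ** 2)))


labels = np.array([[0.0, 1.0], [1.0, 0.0]])
prediction = np.array([[0.0, 1.0], [1.0, 2.0]])
# One error of 2 among 4 entries: sqrt(4 / 4) = 1.0
print(rmse(labels, prediction))  # prints 1.0
```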

saveWeights(filepath, **kwargs)

Save model weights to an HDF5 or TensorFlow-format file.

Parameters
  • filepath (string) – filepath to save model weights. A file suffix of ‘.h5’ or ‘.keras’ saves the weights in HDF5 format; any other path saves them in TensorFlow format.

  • kwargs – optional keyword arguments passed to the Keras save_weights method, see TensorFlow docs.
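The suffix rule for filepath can be stated as a one-line check. This sketch mirrors only the rule as documented above (‘.h5’ or ‘.keras’ means HDF5, anything else means TensorFlow format). weights_format is a hypothetical helper, and TensorFlow's own suffix handling may recognize additional extensions.

```python
def weights_format(filepath):
    """Return the format implied by the documented suffix rule (illustrative only)."""
    # '.h5' and '.keras' select HDF5; any other path falls back to TensorFlow format.
    return "h5" if filepath.endswith((".h5", ".keras")) else "tf"


for path in ["weights.h5", "model.keras", "checkpoints/weights"]:
    print(path, "->", weights_format(path))
```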