TFEstimatorBase
Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow

class spiegel.estimator.tf_estimator_base.TFEstimatorBase(inputShape, numOutputs, checkpointPath='', weightsPath='', loggers=[])
Bases: spiegel.estimator.estimator_base.EstimatorBase
Parameters
inputShape (tuple) – Shape of matrix that will be passed to model input
numOutputs – Number of outputs the model has
checkpointPath (string, optional) – If given, checkpoints will be saved to this location during training, defaults to “”
weightsPath (string, optional) – If given, model weights will be loaded from this file, defaults to “”

addTestingData(input, output, batchSize=64)
Create a tf.data.Dataset from input and output for model testing, batching the data if desired.
Parameters
input (np.array) – Matrix of data to use as testing data
output (np.array) – Matrix of data to use as ground truth for the testing data
batchSize (int, optional) – If provided, data will be batched into batches of this size; set to None or 0 to prevent batching. Defaults to 64
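Conceptually, the batchSize behaviour described above resembles building a dataset with tf.data.Dataset.from_tensor_slices((input, output)) and then calling .batch(batchSize), where the final batch may be smaller than the rest. That addTestingData is built this way is an assumption. The sketch below reproduces the documented behaviour with NumPy only, so it runs without TensorFlow; make_batches is a hypothetical helper, not part of spiegel.

```python
import numpy as np


def make_batches(inputs, outputs, batch_size=64):
    """Pair inputs with ground-truth outputs and optionally batch them.

    Hypothetical helper mirroring the documented batchSize rule: a positive
    batch_size yields (input, output) batches of that size (the last batch
    may be smaller), while None or 0 yields one pair per example.
    """
    if len(inputs) != len(outputs):
        raise ValueError("inputs and outputs must have the same length")
    if not batch_size:
        # No batching: one (input, output) pair per example
        return [(inputs[i], outputs[i]) for i in range(len(inputs))]
    # Consecutive slices of batch_size rows; the remainder forms the last batch
    return [
        (inputs[start:start + batch_size], outputs[start:start + batch_size])
        for start in range(0, len(inputs), batch_size)
    ]


x = np.arange(20).reshape(10, 2)
y = np.arange(10)
batches = make_batches(x, y, batch_size=4)
print([b[0].shape for b in batches])
```

With 10 examples and a batch size of 4 this gives batches of 4, 4 and 2 rows, matching what tf.data's batch does when drop_remainder is left at its default of False.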

addTrainingData(input, output, batchSize=64, shuffleSize=None)
Create a tf.data.Dataset from input and output, then optionally shuffle and batch it for training.
Parameters
input (np.array) – Matrix of training data
output (np.array) – Matrix of ground truth for the training data
batchSize (int, optional) – If provided, data will be batched into batches of this size; set to None or 0 to prevent batching. Defaults to 64
shuffleSize (int, optional) – If provided, data will be shuffled using a buffer of shuffleSize elements. Defaults to None, in which case no shuffling occurs
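A shuffle "buffer size" refers to the buffered shuffling used by tf.data.Dataset.shuffle: the buffer is filled with shuffleSize elements, each output is drawn at random from the buffer, and the freed slot is refilled from the input. A small buffer therefore only mixes nearby examples, while a buffer at least as large as the dataset gives a full shuffle. The pure-Python sketch below illustrates that technique; buffered_shuffle is a hypothetical name, and whether spiegel shuffles before or after batching is an assumption not settled by this page.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a buffer-limited random order, in the style of
    tf.data.Dataset.shuffle(buffer_size)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer; nothing is emitted yet
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain whatever is left in random order
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=3, seed=0)))
```

A buffer size of 1 leaves the order unchanged, which is why a very small shuffleSize has little effect on sorted or grouped training data.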

abstract buildModel()
Abstract method that should contain the model definition when implemented
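Because buildModel is abstract, TFEstimatorBase cannot be instantiated directly; a concrete estimator subclasses it and supplies the model definition. The sketch below shows that pattern with Python's abc module using a hypothetical stand-in class, EstimatorSketch, rather than spiegel itself, so it runs without spiegel or TensorFlow installed. Storing the model on self.model is an assumption, and the NumPy weight matrix is only a placeholder for the tf.keras model a real subclass would build.

```python
from abc import ABC, abstractmethod

import numpy as np


class EstimatorSketch(ABC):
    """Hypothetical stand-in for TFEstimatorBase that models only the
    abstract buildModel() contract; it is not spiegel's implementation."""

    def __init__(self, inputShape, numOutputs):
        self.inputShape = inputShape
        self.numOutputs = numOutputs
        self.model = None  # assumption: the built model is stored here

    @abstractmethod
    def buildModel(self):
        """Subclasses put their model definition here."""


class LinearEstimator(EstimatorSketch):
    def buildModel(self):
        # A real subclass would build a tf.keras model here; a NumPy weight
        # matrix keeps this sketch runnable without TensorFlow.
        nInputs = int(np.prod(self.inputShape))
        self.model = np.zeros((nInputs, self.numOutputs))


estimator = LinearEstimator(inputShape=(4, 4), numOutputs=3)
estimator.buildModel()
print(estimator.model.shape)
```

Instantiating EstimatorSketch directly raises TypeError, which is the same guard an abstract base class gives TFEstimatorBase.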

fit(epochs=1, callbacks=[], **kwargs)
Train the model for a fixed number of epochs on the training data, also using validation data if it has been added to this estimator.
Parameters
epochs (int, optional) – Number of epochs to train the model for. Defaults to 1
callbacks (list, optional) – List of callbacks to use during training. Defaults to []
kwargs – Keyword arguments passed to the model's fit method; see the TensorFlow docs.

loadModelFromCheckpoint()
Load model weights from checkpoint

loadWeights(filepath, **kwargs)
Load model weights from an HDF5 or TensorFlow format file.
Parameters
filepath (string) – Filepath of the saved model weights
kwargs – Optional keyword arguments passed to the TensorFlow load_weights method; see the TensorFlow docs.

predict(input)
Run prediction on input
Parameters
input (np.array) – Matrix of input data to run predictions on. Can be a single instance of data or a batch.
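Keras models expect a leading batch dimension on their input, so accepting "a single instance of data or a batch" implies the single-instance case gets a batch axis of size one added at some point. The sketch below shows one way to do that by comparing the input's shape against inputShape; ensure_batch is a hypothetical helper, and whether predict normalizes its input exactly like this is an assumption.

```python
import numpy as np


def ensure_batch(x, inputShape):
    """Return x with a leading batch axis.

    Hypothetical helper: a single instance whose shape equals inputShape
    becomes a batch of one; an array shaped (N, *inputShape) is returned
    unchanged; anything else is rejected.
    """
    x = np.asarray(x)
    expected = tuple(inputShape)
    if x.shape == expected:
        return x[np.newaxis, ...]  # single instance -> batch of one
    if x.shape[1:] == expected:
        return x  # already a batch
    raise ValueError(f"input shape {x.shape} does not match {expected}")


single = np.zeros((4, 4))
batch = np.zeros((8, 4, 4))
print(ensure_batch(single, (4, 4)).shape, ensure_batch(batch, (4, 4)).shape)
```

Checking the full shape, rather than only the number of dimensions, catches inputs with the right rank but the wrong size before they reach the model.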

static rootMeanSquaredError(labels, prediction)
Static method for calculating root mean squared error between predictions and targets
Parameters
labels (Tensor) – Matrix of ground truth labels
prediction (Tensor) – Matrix of predictions
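Root mean squared error is RMSE = sqrt(mean((labels - prediction)^2)). In TensorFlow this is commonly written as tf.sqrt(tf.reduce_mean(tf.square(labels - prediction))). The NumPy sketch below computes the same quantity so the arithmetic can be checked without TensorFlow. Averaging over every element, rather than per sample or per output, is an assumption about how rootMeanSquaredError reduces.

```python
import numpy as np


def root_mean_squared_error(labels, prediction):
    """RMSE over all elements: sqrt(mean((labels - prediction) ** 2))."""
    labels = np.asarray(labels, dtype=float)
    prediction = np.asarray(prediction, dtype=float)
    # Square the element-wise error, average over every element, then take the root
    return float(np.sqrt(np.mean(np.square(labels - prediction))))


# Squared errors are 4, 0, 0, 0; their mean is 1, so the RMSE is 1.0
print(root_mean_squared_error([0, 0, 0, 0], [2, 0, 0, 0]))
```

Because the error is squared before averaging, RMSE weights a few large parameter errors more heavily than many small ones.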

saveWeights(filepath, **kwargs)
Save model weights to an HDF5 or TensorFlow format file.
Parameters
filepath (string) – Filepath to save model weights to. A file suffix of ‘.h5’ or ‘.keras’ saves in HDF5 format; any other suffix saves in TensorFlow format.
kwargs – Optional keyword arguments passed to the TensorFlow save_weights method; see the TensorFlow docs.
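The suffix rule described for saveWeights can be stated as a small function. The sketch below only mirrors that documented check so the behaviour is easy to test; weights_save_format is a hypothetical name, and the strings "h5" and "tf" are labels for the two formats, not a claim about how spiegel or TensorFlow passes the format internally.

```python
def weights_save_format(filepath):
    """Return "h5" for paths ending in .h5 or .keras, else "tf".

    Hypothetical helper mirroring the suffix rule documented for saveWeights.
    """
    return "h5" if filepath.endswith((".h5", ".keras")) else "tf"


for path in ["weights.h5", "model.keras", "checkpoints/weights"]:
    print(path, "->", weights_save_format(path))
```

A path with no recognised suffix, such as a bare checkpoint prefix, falls through to the TensorFlow format.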