TFEstimatorBase
Abstract Base Class for Estimating Synthesizer Parameters using TensorFlow
-
class spiegel.estimator.tf_estimator_base.TFEstimatorBase(inputShape, numOutputs, checkpointPath='', weightsPath='', loggers=[])
Bases: spiegel.estimator.estimator_base.EstimatorBase
- Parameters
inputShape (tuple) – Shape of the matrix that will be passed to the model input
numOutputs – Number of outputs of the model
checkpointPath (string, optional) – If given, checkpoints will be saved to this location during training, defaults to “”
weightsPath (string, optional) – If given, model weights will be loaded from this file, defaults to “”
-
addTestingData(input, output, batchSize=64)
Create a tf.data.Dataset from input and output for model testing, batching the data if desired
- Parameters
input (np.array) – matrix of data to use as testing data
output (np.array) – matrix of data to use as ground truth for testing data
batchSize (int, optional) – If provided, the data will be grouped into batches of this size; set to None or 0 to disable batching, defaults to 64
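As a plain-Python illustration of the batching rule (not the method's actual tf.data code), the hypothetical helper batch_pairs below pairs each input row with its ground truth and groups the pairs; a batch size of None or 0 returns the pairs unbatched. The last batch may be smaller than the rest, which matches tf.data.Dataset.batch when drop_remainder is False.

```python
from typing import List, Optional, Sequence, Tuple


def batch_pairs(inputs: Sequence, outputs: Sequence, batch_size: Optional[int] = 64):
    """Hypothetical helper: pair inputs with ground truth, then optionally batch.

    A batch_size of None or 0 disables batching, mirroring the documented rule.
    """
    if len(inputs) != len(outputs):
        raise ValueError("input and output must have the same number of rows")
    pairs: List[Tuple] = list(zip(inputs, outputs))
    if not batch_size:  # None or 0 -> return unbatched pairs
        return pairs
    # Slice into consecutive batches; the final batch keeps any remainder
    return [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]


batches = batch_pairs([[0.1], [0.2], [0.3], [0.4], [0.5]], [[1], [2], [3], [4], [5]], batch_size=2)
print([len(b) for b in batches])  # [2, 2, 1]
```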
-
addTrainingData(input, output, batchSize=64, shuffleSize=None)
Create a tf.data.Dataset from input and output, then shuffle and batch the data for training
- Parameters
input (np.array) – matrix of training data
output (np.array) – matrix of ground truth for the training data
batchSize (int, optional) – If provided, the data will be grouped into batches of this size; set to None or 0 to disable batching, defaults to 64
shuffleSize (int, optional) – If provided, the data will be shuffled using a buffer of shuffleSize elements, defaults to None (no shuffling)
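The shuffle buffer size controls how thoroughly the data is mixed: tf.data.Dataset.shuffle keeps a buffer of that many elements and emits a random one each step, so a buffer smaller than the dataset only shuffles locally. The generator below is a plain-Python sketch of that buffered scheme; buffered_shuffle is a hypothetical name, and treating 0 like None is an assumption of this sketch.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: Optional[int], seed: Optional[int] = None) -> Iterator[T]:
    """Yield items in pseudo-random order using a fixed-size shuffle buffer.

    buffer_size of None leaves the order unchanged (the documented default);
    in this sketch 0 is treated the same way.
    """
    if not buffer_size:
        yield from items
        return
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a randomly chosen element; the next item refills its slot
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: emit the remaining buffered elements in random order
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), 3, seed=1)))
```

With a buffer of 1 the order is unchanged, and with a buffer at least as large as the dataset every permutation is possible; intermediate sizes trade memory for randomness.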
-
abstract buildModel()
Abstract method; subclasses implement it to define the model
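The "abstract" marker typically means the method is declared with Python's abc.abstractmethod, so a subclass that does not override it cannot be instantiated. The sketch below shows that pattern with hypothetical classes (EstimatorSketch, DenseSketch); it does not use Spiegel's classes, and the dict returned by buildModel is only a placeholder for a real tf.keras model.

```python
from abc import ABC, abstractmethod


class EstimatorSketch(ABC):
    """Hypothetical stand-in for an abstract estimator base class."""

    def __init__(self, numOutputs: int):
        self.numOutputs = numOutputs

    @abstractmethod
    def buildModel(self):
        """Subclasses must override this with their model definition."""


class DenseSketch(EstimatorSketch):
    """Hypothetical concrete subclass."""

    def buildModel(self):
        # Placeholder "model": a plain dict describing one dense layer,
        # standing in for a real tf.keras model definition
        return {"type": "dense", "units": self.numOutputs}


print(DenseSketch(4).buildModel())  # {'type': 'dense', 'units': 4}
# EstimatorSketch(4) raises TypeError because buildModel is not implemented
```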
-
fit(epochs=1, callbacks=[], **kwargs)
Train the model for a fixed number of epochs on the training data, also evaluating on validation data if it has been added to this estimator
- Parameters
epochs (int, optional) – Number of epochs to train model on, defaults to 1
callbacks (list, optional) – List of callback functions for training, defaults to []
kwargs – Keyword arguments passed to the model's fit method, see the TensorFlow docs.
-
loadModelFromCheckpoint()
Load model weights from the saved checkpoint
-
loadWeights(filepath, **kwargs)
Load model weights from an HDF5 or TensorFlow file
- Parameters
filepath (string) – filepath to saved model weights
kwargs – optional keyword arguments passed to the TensorFlow load_weights method, see the TensorFlow docs.
-
predict(input)
Run prediction on input
- Parameters
input (np.array) – matrix of input data to run predictions on. Can be a single instance of data or a batch.
-
static rootMeanSquaredError(labels, prediction)
Static method that calculates the root mean squared error between predictions and targets
- Parameters
labels (Tensor) – Matrix of ground truth labels
prediction (Tensor) – Matrix of predictions
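Root mean squared error is the square root of the mean squared difference between labels and predictions. In TensorFlow this is commonly written as tf.sqrt(tf.reduce_mean(tf.square(labels - prediction))); whether this method averages over all elements or per output is not stated here, so the plain-Python sketch below (root_mean_squared_error is a hypothetical name) assumes a mean over every element of two equally shaped matrices.

```python
import math
from typing import Sequence


def root_mean_squared_error(labels: Sequence[Sequence[float]],
                            prediction: Sequence[Sequence[float]]) -> float:
    """RMSE averaged over every element of two equally shaped matrices (assumption)."""
    if len(labels) != len(prediction):
        raise ValueError("labels and prediction must have the same shape")
    squared = []
    for label_row, pred_row in zip(labels, prediction):
        if len(label_row) != len(pred_row):
            raise ValueError("labels and prediction must have the same shape")
        # Collect squared error for each element in the row
        squared.extend((y - p) ** 2 for y, p in zip(label_row, pred_row))
    return math.sqrt(sum(squared) / len(squared))


print(root_mean_squared_error([[3.0, 0.0], [0.0, 4.0]], [[0.0, 0.0], [0.0, 0.0]]))  # 2.5
```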
-
saveWeights(filepath, **kwargs)
Save model weights to an HDF5 or TensorFlow file.
- Parameters
filepath (string) – filepath to save model weights to. A file suffix of ‘.h5’ or ‘.keras’ saves in HDF5 format; any other suffix saves in TensorFlow format.
kwargs – optional keyword arguments passed to the TensorFlow save_weights method, see the TensorFlow docs.
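The suffix rule for filepath can be expressed as a small function. weights_format is a hypothetical helper that only restates the documented rule (‘.h5’ or ‘.keras’ means HDF5, anything else means TensorFlow format); it is not part of the class and does not save anything.

```python
import os


def weights_format(filepath: str) -> str:
    """Return "h5" or "tf" according to the documented suffix rule."""
    suffix = os.path.splitext(filepath)[1]
    # Only these two suffixes select HDF5; every other path uses TensorFlow format
    return "h5" if suffix in (".h5", ".keras") else "tf"


print(weights_format("weights/model.h5"))    # h5
print(weights_format("checkpoints/model"))   # tf
```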