import types, random, copy, warnings, datetime, os
import tensorflow as tf, numpy as np
from keras.models import Model, Sequential
from keras.wrappers.scikit_learn import BaseWrapper
from keras.utils.np_utils import to_categorical
from keras.utils.generic_utils import has_arg
from keras.utils.vis_utils import plot_model
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard
from keras.backend import tensorflow_backend as KTF
from keras import backend as K
from sklearn.utils.validation import check_is_fitted
from ..utils._vis import display_log
class BaseModel(BaseWrapper):
    """
    Base class of the scikit-learn style wrappers for Keras models.

    Arguments
    ----------
    build_fn: callable function or class instance.
        The `build_fn` should construct, compile and return a Keras model, which
        will then be used to fit/predict. One of the following
        three values could be passed to `build_fn`:
        1. A function
        2. An instance of a class that implements the `__call__` method
        3. None. This means you implement a class that inherits from either
        `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
        present class will then be treated as the default `build_fn`.
    model_id: str or None, default=None.
        Used as the log directory / file name.
        When model_id is None, it is generated from the current date and time.
    logroot_dir: str or None, default=None.
        Log root directory. This model's log files are saved under "`logroot_dir`/`model_id`/".
        When logroot_dir is None, no log file is saved.
    tb_display_url: str or None, default=None.
        TensorBoard's URL. When `fit` runs, TensorBoard and the Keras model plot are displayed on Jupyter.
        When tb_display_url is None, nothing is displayed.
    period: int or None, default=None.
        Interval (number of epochs) between model checkpoints (keras.callbacks.ModelCheckpoint's arg).
        Requires `logroot_dir` (checkpoints are written to the log directory).
        When period is None, the model is not saved unless a suitable `callbacks` argument is given to `fit`.
    patience: int or None, default=None.
        Number of epochs with no improvement after which training will be stopped (keras.callbacks.EarlyStopping's arg).
        When patience is None, early stopping is not used unless a suitable `callbacks` argument is given to `fit`.
    random_state: int or None, default=None.
        The seed used by the random number generators
        (tensorflow, numpy and python `random`).
    reuse: bool, default=True.
        When reuse is True, the tensorflow session is not re-initialized on every call to `fit`.
        When reuse is False, each `fit` rebuilds the model in a freshly initialized session.
    **sk_params: dictionary.
        Model parameters & fitting parameters.
        Legal model parameters are the arguments of `build_fn`.
    """
    def __init__(self, build_fn=None,
                 model_id=None, logroot_dir=None, tb_display_url=None,
                 period=None, patience=None, random_state=None,
                 reuse=True, **sk_params):
        super().__init__(build_fn=build_fn, **sk_params)
        self.model_id = model_id
        self.logroot_dir = logroot_dir
        self.tb_display_url = tb_display_url
        self.random_state = random_state
        self.reuse = reuse
        # Raw arguments for the callbacks built in `_mk_cbks`; the keys
        # "log_dir", "monitor" and "filepath" are added lazily later on.
        self._cb_params = {"period": period, "patience": patience}

    # Utils
    def check_params(self, params):
        """
        Check for user typos in `params`.

        Arguments
        ----------
        params: dictionary.
            The parameters to be checked.

        Raises
        ----------
        ValueError: if any member of `params` is not a valid argument.
        """
        legal_params_fns = [Sequential.fit, Sequential.predict,
                            Sequential.predict_classes, Sequential.evaluate,
                            Sequential.fit_generator, Sequential.predict_generator,
                            Sequential.evaluate_generator,
                            Model.fit, Model.predict,
                            Model.evaluate,
                            Model.fit_generator, Model.predict_generator,
                            Model.evaluate_generator]
        if self.build_fn is None:
            legal_params_fns.append(self.__call__)
        elif (not isinstance(self.build_fn, types.FunctionType) and
              not isinstance(self.build_fn, types.MethodType)):
            legal_params_fns.append(self.build_fn.__call__)
        else:
            legal_params_fns.append(self.build_fn)
        for params_name in params:
            for fn in legal_params_fns:
                if has_arg(fn, params_name):
                    break
            else:
                # `nb_epoch` is tolerated without validation
                # (presumably the legacy name of `epochs`).
                if params_name != 'nb_epoch':
                    raise ValueError('{} is not a legal parameter'.format(params_name))

    def _check_params(self, params, legal_params_fns):
        """
        Check for user typos in `params` against a given list of functions.

        Arguments
        ----------
        params: dictionary.
            The parameters to be checked.
        legal_params_fns: list of functions.
            The functions whose arguments are considered legal.

        Raises
        ----------
        ValueError: if any member of `params` is not a valid argument.
        """
        for params_name in params:
            for fn in legal_params_fns:
                if has_arg(fn, params_name):
                    break
            else:
                raise ValueError('{} is not a legal parameter'.format(params_name))

    def set_session_config(self, config):
        """
        Set the tensorflow session config used when a session is (re)created.

        Arguments
        ----------
        config:
            Config for the tensorflow session.
        """
        self.session_config = config

    def _filter_n_set_params(self, fn, **params):
        """
        Call `fn` with only those entries of `params` that `fn` accepts.

        Arguments
        ----------
        fn: callable.
            Function (or class) to call.
        **params: dictionary.
            Dictionary of parameter names mapped to their values.

        Returns
        ----------
        The result of `fn(**accepted_params)`.
        """
        args = {}
        for params_name in params:
            if has_arg(fn, params_name):
                args[params_name] = params[params_name]
        return fn(**args)

    # Env setting
    def _set_log(self):
        """
        Create the log directory `logroot_dir/model_id` and record it in `_cb_params`.
        """
        if self.model_id is None:
            self.model_id = datetime.datetime.today().strftime("%Y%m%d_%H%M%S")
        log_dir = os.path.join(self.logroot_dir, str(self.model_id))
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(log_dir, mode=0o777, exist_ok=True)
        # makedirs' mode is subject to the umask, so set it explicitly.
        os.chmod(log_dir, mode=0o777)
        self._cb_params["log_dir"] = log_dir

    def _mk_cbks(self):
        """
        Build the Keras callbacks implied by this wrapper's settings.

        Returns
        ----------
        cbks: list of keras callbacks.
            EarlyStopping (when `patience` is set), TensorBoard (when a log
            directory exists) and ModelCheckpoint (when `period` is set).

        Raises
        ----------
        ValueError: if `period` is set but there is no log directory to save
            checkpoints into.
        """
        cbks = []
        if self._cb_params["patience"] is not None:
            cbks.append(self._filter_n_set_params(EarlyStopping, **self._cb_params))
        if "log_dir" in self._cb_params:
            plot_model(self.model,
                       to_file=os.path.join(self._cb_params["log_dir"], str(self.model_id) + ".png"),
                       show_layer_names=True, show_shapes=True)
            cbks.append(self._filter_n_set_params(TensorBoard, **self._cb_params))
        if self._cb_params["period"] is not None:
            if "filepath" not in self._cb_params:
                if "log_dir" not in self._cb_params:
                    raise ValueError("`period` requires `logroot_dir` to be set, "
                                     "because checkpoints are saved under the log directory.")
                self._cb_params["filepath"] = os.path.join(self._cb_params["log_dir"],
                                                           "model.epoch{epoch:04d}.h5")
            cbks.append(self._filter_n_set_params(ModelCheckpoint, **self._cb_params))
        return cbks

    def _init_session(self, random_state):
        """
        Clear the Keras/tensorflow session, reseed all RNGs and install a new global session.

        Arguments
        ----------
        random_state:
            The seed used by the random number generators
            (tensorflow, numpy and python `random`).
        """
        KTF.clear_session()
        tf.set_random_seed(random_state)
        np.random.seed(random_state)
        random.seed(random_state)
        if hasattr(self, "session_config"):
            self.session = tf.Session(config=self.session_config)
        else:
            self.session = tf.Session("")
        KTF.set_session(self.session)

    # Modeling
    def _build_model(self):
        """
        Build the Keras model (stored as `self.model`) and set up the environment.
        """
        # With reuse=False every rebuild starts from a fresh, reseeded session
        # (as documented for `reuse`); otherwise the session is created once.
        if (not self.reuse) or (not hasattr(self, "session")):
            self._init_session(self.random_state)
        if self.build_fn is None:
            self.model = self.__call__(**self.filter_sk_params(self.__call__))
        elif (not isinstance(self.build_fn, types.FunctionType) and
              not isinstance(self.build_fn, types.MethodType)):
            self.model = self.build_fn(
                **self.filter_sk_params(self.build_fn.__call__))
        else:
            self.model = self.build_fn(**self.filter_sk_params(self.build_fn))
        if self.logroot_dir is not None:
            if "log_dir" not in self._cb_params:
                self._set_log()

    def _ck_callbacks(self, validation_data, callbacks):
        """
        Decide which callbacks this model uses (only the first time it is called).

        Arguments
        ----------
        validation_data: tuple of array-like `(x, y)` or None.
            Data for validation.
            When validation_data is None, the callbacks' `monitor` is "loss";
            otherwise it is "val_loss".
        callbacks: list of keras callbacks or None.
            When callbacks is None, this model's default callbacks are built.
            Otherwise, this list is used as this model's callbacks.
        """
        if hasattr(self, "callbacks"):
            if callbacks is not None:
                warnings.warn("This instance already is set callbacks. So input callbacks is ignored.")
        else:
            if callbacks is None:
                if "monitor" not in self._cb_params:
                    if validation_data is None:
                        self._cb_params["monitor"] = "loss"
                    else:
                        self._cb_params["monitor"] = "val_loss"
                callbacks = self._mk_cbks()
            self.callbacks = callbacks

    def fit(self, x, y, validation_data=None, callbacks=None, **kwargs):
        """
        Construct a new model with `build_fn` (unless reused) and fit it to `(x, y)`.

        Arguments
        ----------
        x : array-like, shape `(n_samples, n_features)`
            Training samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True values for x.
        validation_data: tuple of array-like `(x, y)` or None.
            Data for validation.
            When validation_data is None, validation is skipped.
        callbacks: list of keras callbacks or None.
            When callbacks is None, this model's default callbacks are used.
            Otherwise, this list is used as this model's callbacks.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `fit`.

        Returns
        ----------
        history : object
            Details about the training history at each epoch.
        """
        if (not self.reuse) or (not hasattr(self, "model")):
            self._build_model()
        self._ck_callbacks(validation_data=validation_data, callbacks=callbacks)
        fit_args = copy.deepcopy(self.filter_sk_params(self.model.fit))
        fit_args.update(kwargs)
        if self.tb_display_url is not None:
            display_log(path=self.logroot_dir, tb_display_url=self.tb_display_url, cr_model_id=self.model_id)
            # An explicitly requested verbosity is silenced while the
            # TensorBoard view is displayed (presumably to keep output clean).
            if "verbose" in fit_args:
                fit_args["verbose"] = 0
        history = self.model.fit(x, y, validation_data=validation_data, callbacks=self.callbacks, **fit_args)
        return history

    def evaluate(self, x, y, **kwargs):
        """
        Return the model's loss (and metrics, if compiled) on `(x, y)`.

        Arguments
        ----------
        x : array-like, shape `(n_samples, n_features)`
            Samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True values for x.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `evaluate`.

        Returns
        ----------
        loss
            The model's loss, as returned by Keras `evaluate`.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.evaluate, kwargs)
        return self.model.evaluate(x, y, **kwargs)

    def fit_generator(self, generator, steps_per_epoch, validation_data=None, callbacks=None, **kwargs):
        """
        Construct a new model with `build_fn` (unless reused) and fit it to the generator's output.

        Arguments
        ----------
        generator: generator
            The output of the generator must be either
            - a tuple (inputs, targets)
            - a tuple (inputs, targets, sample_weights).
        steps_per_epoch: int.
            Total number of steps (batches of samples) per epoch.
        validation_data: generator, tuple or None.
            This can be either
            - A generator for the validation data
            - A tuple (inputs, targets)
            - A tuple (inputs, targets, sample_weights).
            When validation_data is None, validation is skipped.
        callbacks: list of keras callbacks or None.
            When callbacks is None, this model's default callbacks are used.
            Otherwise, this list is used as this model's callbacks.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `fit_generator`.

        Returns
        ----------
        history: object
            Details about the training history at each epoch.
        """
        if (not self.reuse) or (not hasattr(self, "model")):
            self._build_model()
        self._ck_callbacks(validation_data=validation_data, callbacks=callbacks)
        fit_args = copy.deepcopy(self.filter_sk_params(self.model.fit_generator))
        fit_args.update(kwargs)
        history = self.model.fit_generator(generator, steps_per_epoch, validation_data=validation_data,
                                           callbacks=self.callbacks, **fit_args)
        return history

    def predict_generator(self, generator, steps, **kwargs):
        """
        Generate predictions for the input samples from a data generator.

        Arguments
        ----------
        generator: generator
            The output of the generator must be either
            - a tuple (inputs, targets)
            - a tuple (inputs, targets, sample_weights).
        steps: int.
            Total number of steps (batches of samples).
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `predict_generator`.

        Returns
        ----------
        preds:
            Predictions.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.predict_generator, kwargs)
        return self.model.predict_generator(generator=generator, steps=steps, **kwargs)

    def evaluate_generator(self, generator, steps, **kwargs):
        """
        Evaluate the model on a data generator.

        Arguments
        ----------
        generator: generator
            The output of the generator must be either
            - a tuple (inputs, targets)
            - a tuple (inputs, targets, sample_weights).
        steps: int.
            Total number of steps (batches of samples).
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `evaluate_generator`.

        Returns
        ----------
        loss
            The model's loss, as returned by Keras `evaluate_generator`.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.evaluate_generator, kwargs)
        return self.model.evaluate_generator(generator=generator, steps=steps, **kwargs)

    def save(self, filepath, overwrite=True, include_optimizer=True):
        """
        Save the fitted Keras model to an HDF5 file.

        Arguments
        ----------
        filepath: str.
            Path where to save the model.
        overwrite: bool, default=True.
            Whether to silently overwrite any existing file at the target
            location, or instead ask the user with a manual prompt.
        include_optimizer: bool, default=True.
            If True, save the optimizer's state together with the model.
        """
        check_is_fitted(self, "model")
        self.model.save(filepath=filepath, overwrite=overwrite, include_optimizer=include_optimizer)
class KerasClassifier(BaseModel):
    """
    Implementation of the scikit-learn classifier API for Keras.

    Arguments
    ----------
    build_fn: callable function or class instance.
        The `build_fn` should construct, compile and return a Keras model, which
        will then be used to fit/predict. One of the following
        three values could be passed to `build_fn`:
        1. A function
        2. An instance of a class that implements the `__call__` method
        3. None. This means you implement a class that inherits from either
        `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
        present class will then be treated as the default `build_fn`.
    model_id: str or None, default=None.
        Used as the log directory / file name.
        When model_id is None, it is generated from the current date and time.
    logroot_dir: str or None, default=None.
        Log root directory. This model's log files are saved under "`logroot_dir`/`model_id`/".
        When logroot_dir is None, no log file is saved.
    tb_display_url: str or None, default=None.
        TensorBoard's URL. When `fit` runs, TensorBoard and the Keras model plot are displayed on Jupyter.
        When tb_display_url is None, nothing is displayed.
    period: int or None, default=None.
        Interval (number of epochs) between model checkpoints (keras.callbacks.ModelCheckpoint's arg).
        When period is None, the model is not saved unless a suitable `callbacks` argument is given to `fit`.
    patience: int or None, default=None.
        Number of epochs with no improvement after which training will be stopped (keras.callbacks.EarlyStopping's arg).
        When patience is None, early stopping is not used unless a suitable `callbacks` argument is given to `fit`.
    random_state: int or None, default=None.
        The seed used by the random number generators
        (tensorflow, numpy and python `random`).
    reuse: bool, default=True.
        When reuse is True, the tensorflow session is not re-initialized on every call to `fit`.
    **sk_params: dictionary.
        Model parameters & fitting parameters.
        Legal model parameters are the arguments of `build_fn`.
    """
    def _ck_y(self, y):
        """
        Convert labels `y` into the target format fed to the model.

        Arguments
        ----------
        y: array-like, shape `(n_samples,)`, `(n_samples, 1)` or `(n_samples, n_classes)`.
            Labels. A 2D array with more than one column is treated as an
            already-encoded indicator matrix and returned unchanged.

        Returns
        ----------
        y: numpy array.
            1D labels are mapped to indices of `self.classes_` and one-hot
            encoded with `self.n_classes_` columns; `(n_samples, 1)` labels are
            mapped to indices of `self.classes_` and keep their column shape.
        """
        y = np.asarray(y)
        if y.ndim == 2 and y.shape[1] > 1:
            # Already an indicator matrix (cf. `_ck_classes`); nothing to map.
            return y
        y = np.searchsorted(self.classes_, y)
        if y.ndim != 2:
            # Fix the width to the number of known classes so that data that
            # lacks the highest class(es) still matches the model's output.
            y = to_categorical(y, num_classes=self.n_classes_)
        return y

    def _ck_classes(self, y):
        """
        Set `classes_` and `n_classes_` from the training labels `y`.

        Arguments
        ----------
        y: array-like, shape `(n_samples,)`, `(n_samples, 1)` or `(n_samples, n_classes)`.
            Training labels. For an indicator matrix the classes are the
            column indices; otherwise they are the sorted unique labels.

        Raises
        ----------
        ValueError: if `y` does not have one of the supported shapes.
        """
        y = np.array(y)
        if len(y.shape) == 2 and y.shape[1] > 1:
            self.classes_ = np.arange(y.shape[1])
        elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
            self.classes_ = np.unique(y)
        else:
            raise ValueError('Invalid shape for y: ' + str(y.shape))
        self.n_classes_ = len(self.classes_)

    def fit(self, x, y, validation_data=None, callbacks=None, **kwargs):
        """
        Construct a new model with `build_fn` (unless reused) and fit it to `(x, y)`.

        Arguments
        ----------
        x : array-like, shape `(n_samples, n_features)`
            Training samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for x.
        validation_data: tuple of array-like `(x, y)` or None.
            Data for validation.
            When validation_data is None, validation is skipped.
        callbacks: list of keras callbacks or None.
            When callbacks is None, this model's default callbacks are used.
            Otherwise, this list is used as this model's callbacks.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `fit`.

        Returns
        ----------
        history : object
            Details about the training history at each epoch.
        """
        if not hasattr(self, "n_classes_"):
            self._ck_classes(y)
        return super().fit(x, self._ck_y(y), validation_data=validation_data, callbacks=callbacks, **kwargs)

    def predict_proba(self, x, **kwargs):
        """
        Return class probability estimates for the given test data.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `predict`.

        Returns
        ----------
        proba: array-like, shape `(n_samples, n_outputs)`
            Class probability estimates.
            In the case of binary classification, to match the scikit-learn
            API, an array of shape `(n_samples, 2)` is returned
            (instead of `(n_samples, 1)` as in Keras).
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.predict, kwargs)
        probs = self.model.predict(x, **kwargs)
        # A single output column means binary classification.
        if probs.shape[1] == 1:
            # First column is the probability of class 0, second of class 1.
            probs = np.hstack([1 - probs, probs])
        return probs

    def predict(self, x, **kwargs):
        """
        Return the class predictions for the given test data.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `predict`.

        Returns
        ----------
        preds: array-like, shape `(n_samples,)`
            Class predictions.
        """
        check_is_fitted(self, "model")
        # Run the model once and honour the caller's kwargs.
        probs = self.predict_proba(x, **kwargs)
        return self.classes_[np.argmax(probs, axis=1)]

    def evaluate(self, x, y, **kwargs):
        """
        Return the model's loss (and metrics, if compiled) on `(x, y)`.

        Arguments
        ----------
        x : array-like, shape `(n_samples, n_features)`
            Samples where n_samples is the number of samples
            and n_features is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for x.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `evaluate`.

        Returns
        ----------
        loss
            The model's loss, as returned by Keras `evaluate`.
        """
        # Check first so an unfitted instance raises NotFittedError rather
        # than an AttributeError on `classes_` inside `_ck_y`.
        check_is_fitted(self, "model")
        return super().evaluate(x, self._ck_y(y), **kwargs)

    def fit_generator(self, generator, steps_per_epoch, classes, validation_data=None, callbacks=None, **kwargs):
        """
        Construct a new model with `build_fn` (unless reused) and fit it to the generator's output.

        Arguments
        ----------
        generator: generator
            The output of the generator must be either
            - a tuple (inputs, targets)
            - a tuple (inputs, targets, sample_weights).
            Note: unlike `fit`, this method does not transform the targets
            (labels are not one-hot encoded).
        steps_per_epoch: int.
            Total number of steps (batches of samples) per epoch.
        classes: array-like, shape `(n_classes,)` or `(n_classes, 1)`.
            Class labels, in the order of the model's output columns. Only
            used the first time the classes are set on this instance.
        validation_data: generator, tuple or None.
            This can be either
            - A generator for the validation data
            - A tuple (inputs, targets)
            - A tuple (inputs, targets, sample_weights).
            When validation_data is None, validation is skipped.
        callbacks: list of keras callbacks or None.
            When callbacks is None, this model's default callbacks are used.
            Otherwise, this list is used as this model's callbacks.
        **kwargs: dictionary.
            Legal arguments are the arguments of the built model's `fit_generator`.

        Returns
        ----------
        history : object
            Details about the training history at each epoch.

        Raises
        ----------
        ValueError: if `classes` does not have one of the supported shapes.
        """
        if not hasattr(self, "n_classes_"):
            classes = np.array(classes)
            if (len(classes.shape) == 2 and classes.shape[1] == 1) or len(classes.shape) == 1:
                # Flatten a column vector so that `classes_[indices]` yields labels, not rows.
                self.classes_ = classes.ravel()
                self.n_classes_ = len(self.classes_)
            else:
                raise ValueError('Invalid shape for classes: ' + str(classes.shape))
        return super().fit_generator(generator, steps_per_epoch, validation_data=validation_data,
                                     callbacks=callbacks, **kwargs)

    def score(self, x, y, **kwargs):
        """
        Return the mean accuracy on the given test data and labels.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for x.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `evaluate`.

        Returns
        ----------
        score: float
            Mean accuracy of predictions on x wrt. y.

        Raises
        ----------
        ValueError:
            If the underlying model isn't configured to
            compute accuracy. You should pass `metrics=["accuracy"]` to
            the `.compile()` method of the model.
        """
        check_is_fitted(self, "model")
        y = self._ck_y(y)
        kwargs = self.filter_sk_params(self.model.evaluate, kwargs)
        outputs = self.model.evaluate(x, y, **kwargs)
        if not isinstance(outputs, list):
            outputs = [outputs]
        for name, output in zip(self.model.metrics_names, outputs):
            # Keras has reported the accuracy metric as either name.
            if name in ('acc', 'accuracy'):
                return output
        raise ValueError('The model is not configured to compute accuracy. '
                         'You should pass `metrics=["accuracy"]` to '
                         'the `model.compile()` method.')
class KerasRegressor(BaseModel):
    """
    Implementation of the scikit-learn regressor API for Keras.

    Arguments
    ----------
    build_fn: callable function or class instance.
        The `build_fn` should construct, compile and return a Keras model, which
        will then be used to fit/predict. One of the following
        three values could be passed to `build_fn`:
        1. A function
        2. An instance of a class that implements the `__call__` method
        3. None. This means you implement a class that inherits from either
        `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
        present class will then be treated as the default `build_fn`.
    model_id: str or None, default=None.
        Used as the log directory / file name.
        When model_id is None, it is generated from the current date and time.
    logroot_dir: str or None, default=None.
        Log root directory. This model's log files are saved under "`logroot_dir`/`model_id`/".
        When logroot_dir is None, no log file is saved.
    tb_display_url: str or None, default=None.
        TensorBoard's URL. When `fit` runs, TensorBoard and the Keras model plot are displayed on Jupyter.
        When tb_display_url is None, nothing is displayed.
    period: int or None, default=None.
        Interval (number of epochs) between model checkpoints (keras.callbacks.ModelCheckpoint's arg).
        When period is None, the model is not saved unless a suitable `callbacks` argument is given to `fit`.
    patience: int or None, default=None.
        Number of epochs with no improvement after which training will be stopped (keras.callbacks.EarlyStopping's arg).
        When patience is None, early stopping is not used unless a suitable `callbacks` argument is given to `fit`.
    random_state: int or None, default=None.
        The seed used by the random number generators
        (tensorflow, numpy and python `random`).
    reuse: bool, default=True.
        When reuse is True, the tensorflow session is not re-initialized on every call to `fit`.
    **sk_params: dictionary.
        Model parameters & fitting parameters.
        Legal model parameters are the arguments of `build_fn`.
    """
    def predict(self, x, **kwargs):
        """
        Return predictions for the given test data.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `predict`.

        Returns
        ----------
        preds: array-like, shape `(n_samples,)`
            Predictions, with singleton dimensions squeezed out.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.predict, kwargs)
        return np.squeeze(self.model.predict(x, **kwargs))

    def score(self, x, y, **kwargs):
        """
        Return the negated loss on the given test data and targets.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        y: array-like, shape `(n_samples,)`
            True targets for x.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `evaluate`.

        Returns
        ----------
        score: float
            The loss multiplied by -1, so that greater is better as
            scikit-learn's model selection expects. When the model is
            compiled with metrics, only the loss (first value) is used.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.evaluate, kwargs)
        loss = self.model.evaluate(x, y, **kwargs)
        if isinstance(loss, list):
            return -loss[0]
        return -loss
class KerasModel(BaseModel):
    """
    Implementation of the scikit-learn API for Keras.

    Arguments
    ----------
    build_fn: callable function or class instance.
        The `build_fn` should construct, compile and return a Keras model, which
        will then be used to fit/predict. One of the following
        three values could be passed to `build_fn`:
        1. A function
        2. An instance of a class that implements the `__call__` method
        3. None. This means you implement a class that inherits from either
        `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
        present class will then be treated as the default `build_fn`.
    model_id: str or None, default=None.
        Used as the log directory / file name.
        When model_id is None, it is generated from the current date and time.
    logroot_dir: str or None, default=None.
        Log root directory. This model's log files are saved under "`logroot_dir`/`model_id`/".
        When logroot_dir is None, no log file is saved.
    period: int or None, default=None.
        Interval (number of epochs) between model checkpoints (keras.callbacks.ModelCheckpoint's arg).
        When period is None, the model is not saved unless a suitable `callbacks` argument is given to `fit`.
    patience: int or None, default=None.
        Number of epochs with no improvement after which training will be stopped (keras.callbacks.EarlyStopping's arg).
        When patience is None, early stopping is not used unless a suitable `callbacks` argument is given to `fit`.
    random_state: int or None, default=None.
        The seed used by the random number generators
        (tensorflow, numpy and python `random`).
    reuse: bool, default=True.
        When reuse is True, the tensorflow session is not re-initialized on every call to `fit`.
    greater_is_better: bool, default=True.
        Whether the value returned by the model's `evaluate` is a score
        (default), meaning high is good, or a loss, meaning low is good.
        In the latter case `score` returns it negated.
    tb_display_url: str or None, default=None.
        TensorBoard's URL. When `fit` runs, TensorBoard and the Keras model plot are displayed on Jupyter.
        When tb_display_url is None, nothing is displayed.
    **sk_params: dictionary.
        Model parameters & fitting parameters.
        Legal model parameters are the arguments of `build_fn`.
    """
    def __init__(self,
                 build_fn=None,
                 model_id=None, logroot_dir=None, period=None, patience=None, random_state=None,
                 reuse=True, greater_is_better=True, tb_display_url=None, **sk_params):
        # TODO: review method of inputting `greater_is_better`.
        # `greater_is_better` is used in score method and score method is needed to sklearn's search class.
        # `tb_display_url` is appended last so existing positional calls keep working.
        super().__init__(build_fn=build_fn,
                         model_id=model_id, logroot_dir=logroot_dir, tb_display_url=tb_display_url,
                         period=period, patience=patience, random_state=random_state,
                         reuse=reuse, **sk_params)
        self.greater_is_better = greater_is_better

    def predict(self, x, **kwargs):
        """
        Return predictions for the given test data.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `predict`.

        Returns
        ----------
        preds:
            Raw predictions of the Keras model.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.predict, kwargs)
        return self.model.predict(x, **kwargs)

    def score(self, x, y, **kwargs):
        """
        Return the model's evaluation value on the given test data and targets.

        Arguments
        ----------
        x: array-like, shape `(n_samples, n_features)`
            Test samples where n_samples is the number of samples
            and n_features is the number of features.
        y: array-like, shape `(n_samples,)`
            True targets for x.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of the built model's `evaluate`.

        Returns
        ----------
        score
            The first value returned by `evaluate` (the loss when metrics are
            compiled in), negated when `greater_is_better` is False.
        """
        check_is_fitted(self, "model")
        kwargs = self.filter_sk_params(self.model.evaluate, kwargs)
        loss = self.model.evaluate(x, y, **kwargs)
        # Unpack before negating: `-1 * [loss, ...]` would be an empty list.
        if isinstance(loss, list):
            loss = loss[0]
        if not self.greater_is_better:
            loss = -loss
        return loss