Commit feb1aa56 authored by Lukas Leufen

first implementation of lrDecay, loss and lrCallback, no tests yet

parent 430cc664
2 merge requests: !9 new version v0.2.0, !7 l_p_loss and lrdecay implementation
Pipeline #25683 passed
@@ -2,7 +2,55 @@ __author__ = 'Lukas Leufen'
__date__ = '2019-10-21'

import logging
import math

import keras
import keras.backend as K


def to_list(arg):
    """Wrap a non-list argument in a list; return lists unchanged."""
    if not isinstance(arg, list):
        arg = [arg]
    return arg


class Loss:

    def l_p_loss(self, power):
        """Return a Keras loss: mean of |y_pred - y_true| ** power over the last axis."""
        def loss(y_true, y_pred):
            return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
        return loss


class lrDecay(keras.callbacks.History):

    def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8):
        super(lrDecay, self).__init__()
        self.lr = {'lr': []}
        self.base_lr = base_lr
        self.drop = drop
        self.epochs_drop = epochs_drop

    def on_epoch_begin(self, epoch: int, logs=None):
        # step decay: lr = base_lr * drop ** floor((1 + epoch) / epochs_drop),
        # i.e. the rate is multiplied by `drop` once every `epochs_drop` epochs
        if epoch > 0:
            current_lr = self.base_lr * math.pow(self.drop, math.floor((1 + epoch) / self.epochs_drop))
        else:
            current_lr = self.base_lr
        K.set_value(self.model.optimizer.lr, current_lr)
        self.lr['lr'].append(current_lr)
        logging.info(f"Set learning rate to {current_lr}")
        return K.get_value(self.model.optimizer.lr)


class lrCallback(keras.callbacks.History):

    def __init__(self):
        super(lrCallback, self).__init__()
        self.lr = None

    def on_train_begin(self, logs=None):
        self.lr = []

    def on_epoch_end(self, epoch, logs=None):
        # store the current learning rate value rather than the backend variable
        self.lr.append(K.get_value(self.model.optimizer.lr))
\ No newline at end of file
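
For context, a minimal usage sketch of the pieces above. The model, data, and shapes here are hypothetical and serve only to exercise the callbacks; they are not part of the commit.

import numpy as np
import keras

# toy model and data, assumed purely for illustration
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(10,))])
model.compile(optimizer=keras.optimizers.SGD(lr=0.01), loss=Loss().l_p_loss(4))

x, y = np.random.rand(32, 10), np.random.rand(32, 1)
lr_decay = lrDecay(base_lr=0.01, drop=0.96, epochs_drop=8)
lr_log = lrCallback()
model.fit(x, y, epochs=16, callbacks=[lr_decay, lr_log], verbose=0)

# with base_lr=0.01, drop=0.96, epochs_drop=8 the schedule steps once every
# 8 epochs: 0.01 for epochs 0-6, 0.0096 from epoch 7 (floor(8/8) == 1),
# and ~0.009216 from epoch 15 (floor(16/8) == 2)
print(lr_decay.lr['lr'])
print(lr_log.lr)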