Python/Keras에서 BN이 갱신되지 않는 문제

환경은 Python/ Tensorflow(1.13)/ Keras

요샌 보통 DNN의 모델을 Keras를 이용하고, Back Propagation은 Tensorflow로 손실함수를 조작하는 식으로 쓰는데,

예를들면,

Functional API

v_input = Input(batch_shape=(None, inputshape))
net = BatchNormalization()(net)
net = Dense(512)(net)
net = Relu()(net)
net = BatchNormalization()(net)

output = Dense(outputshape,)(net)
model = Model(inputs=v_input, outputs=output)

input-> BN -> Fully-connected(512)->relu->BN->output

Keras대응으로

model.compile(optimizer='rmsprop',
loss='mean_squared_error',
metrics=['accuracy'])
model.fit(X,Y)

이런식으로 해주면 별 문제가 없다.

학습을 진행하면서 BN의 moving_mean, moving_std도 잘 갱신될 것이다.

그러나 똑같은 방법을

model._make_predict_function()
placeholder_input = tf.placeholder(tf.float32, shape=(None,  inputshape))
placeholder_output = tf.placeholder(tf.float32, shape=(None,  outputshape))
value_output = model([placeholder_input])
loss = tf.reduce_mean(tf.square(placeholder_output-value_output))

optimizer = tf.train.RMSPropOptimizer()
minimize = optimizer.minimize(loss)

session.run(minimize,feed_dict={placeholder_input:X, placeholder_output:Y})

이렇게 해버리면 BN이 작동을 하질 않는다

(각 배치의 평균과 분산은 계산하지만 Layer의 momentum계산을 하지 않는다)


tf에서 조작하는 이유

  1. Keras모델이 직관성이 더 좋다.
  2. Keras에선 Custom Loss를 구성하기가 너무 불편하다(y랑 y_pred에 관한 함수로만 받는다.)
  3. Custom하기 편하다.

원인

보통 BN은 세션 돌리기전 계산이 되도록 해야되는데, global control_dependencies에 moving_mean/std를 갱신하는 키가 들어있지 않다.

그 이유는 Keras의 BN은 ‘레이어가 생성될 때’가 아닌 ‘입력이 들어왔을 때’에 키가 생성되기 때문에

모델를 계획하고 minimize하는걸 그래프를 만들었을때에 키가 생성되지 않아서 graph의 control_dependencies에 추가를 할 수가 없다.

그 상태로 minimize를 계속한다고 해도 moving_mean/var는 평균0/표준편차1인 그대로다…

(일시적인)해결법

keras의 normalization.py를 보면

195~부터

self.add_update([K.moving_average_update(self.moving_mean,
                                         mean,
                                         self.momentum),
                 K.moving_average_update(self.moving_variance,
                                         variance,
                                         self.momentum)],
                inputs)
                #이 부분이 input이 들어왔을때 ops를 생성시키는 condition을 나타낸다

이걸

self.add_update([K.moving_average_update(self.moving_mean,
                                         mean,
                                         self.momentum),
                 K.moving_average_update(self.moving_variance,
                                         variance,
                                         self.momentum)],
                inputs)
                        None)		#inputs)
#이런식으로 바꿔준다

그러면 레이어or모델을 작성하였을때에 model.updates에 moving_mean/std를 갱신하는 텐서가 나오게 되고 추적할 수 있게 된다.

update_ops = []
for update_op in model.updates:
    if "model" in update_op.name:
        #혹 functional api가 아닌 sequential을 사용한다면
        #model이 아닌 sequence라는 문자를 추적하면 될것이다.
        update_ops.append(update_op)
        #아마
        #batch_normalization/moving_average_update 비슷한 텐서가 나올것이다.
                
with tf.control_dependencies(update_ops):
    minimize = optimizer.minimize(modified_loss)
    

이런식으로 갱신할 수 있게 된다.

레이어는 build된 뒤 정상적으로 build되었는지 확인하기 위해 call을 호출하는데 그때에 생성되기 때문에 추가 할 수 있게 된다.

reference

https://github.com/tensorflow/tensorflow/issues/23873