728x90

오늘은 간단하면서도, 막상 찾아보긴 힘든, 벡터를 여러개로 복사하는 것에 대해 글을 쓰겠다.

 

먼저 아래와 같은 변수가 있다고 가정

 

A = tf.random.uniform(shape=[64, 100, 256])

B = tf.random.uniform(shape=[64, 256])

 

하고 싶은 것은, A에 B를 concat 하고싶다고 가정하겠다.

 

이럴 경우 어떻게 해야하나?

 

먼저 Concat을 하려면, 차원이 맞아야 되는데, A의 2차원 즉 100에 해당하는 것이 B에는 없다.

 

이럴 경우 B에서 256짜리를 100번 복사해서 해야된다.

 

즉 A = [Batch, T, hidden], B = [Batch, hidden] 일 경우, B의 2번째 dimension에 B를 T번 복사해야된다.

 

최종으로 원하는것은 B = [Batch, T, hidden] 인 것이다.

 

tf.repeat 함수 (https://www.tensorflow.org/api_docs/python/tf/repeat)를 사용할 건데, 나는 현재 tensorflow 2.0 버전을 쓰고 있고, 여기에서 tf.repeat을 그대로 가져오면 없다고 오류 메시지가 뜬다.

 

tf.repeat 함수는 아래와 같은데,

 

tf.repeat(
    input
, repeats, axis=None, name=None
)

 

위에서 말 했듯이 나는 이게 안되므로,

tf.keras.backend.repeat(https://www.tensorflow.org/api_docs/python/tf/keras/backend/repeat) 함수를 사용 할 것이다.

 

tf.keras.backend.repeat 함수는 아래와 같은데,

 

tf.keras.backend.repeat(
    x
, n
)

 

tf.repeat과 다른 점은 딱히 없는 것 같다. 

 

B = [Batch, hidden] 인 것을, T=100으로 가정하여 B=[Batch, T, hidden]으로 만드는 코드는 아래와 같다.

 

import tensorflow as tf

batch_size = 100
seq_len = 100
hidden_size = 256

a = tf.random.uniform(shape=[batch_size, seq_len, hidden_size])
b = tf.random.uniform(shape=[batch_size, hidden_size])

print('a shape {} b shape {}'.format(a.shape, b.shape))

new_b = tf.keras.backend.repeat(b, n=100)

print('new_b shape {}'.format(new_b.shape))

concat_output = tf.concat(values=(a, new_b), axis=-1)

print('concat_output shape', concat_output.shape)

 

실행 결과는 아래와 같고,

이제 검증을 해야되는데, 100개로 늘렸을 경우 모든 100개가 기존 b의 hidden_size 내의 vector와 값이 같은가? 를 검증해야 된다.

 

아래와 같이 할 수 있다.

 

for i in range(len(new_b[0])):
    check = (b[0] == new_b[0][i])
    print('{}th check {}'.format(i, check))

 

이에 대한 결과는 모두 True

전체 코드

import tensorflow as tf

batch_size = 100
seq_len = 100
hidden_size = 256

a = tf.random.uniform(shape=[batch_size, seq_len, hidden_size])
b = tf.random.uniform(shape=[batch_size, hidden_size])

print('a shape {} b shape {}'.format(a.shape, b.shape))

new_b = tf.keras.backend.repeat(b, n=100)

print('new_b shape {}'.format(new_b.shape))

concat_output = tf.concat(values=(a, new_b), axis=-1)

print('concat_output shape', concat_output.shape)
 
for i in range(len(new_b[0])):
    check = (b[0] == new_b[0][i])
    print('{}th check {}'.format(i, check))

 

 

728x90

+ Recent posts