|
| 1 | +import time |
| 2 | +import os |
| 3 | +import psutil |
| 4 | +import tensorflow as tf |
| 5 | +import tensorlayer as tl |
| 6 | +from exp_config import random_input_generator, MONITOR_INTERVAL, NUM_ITERS, BATCH_SIZE, LERANING_RATE |
| 7 | + |
# Enable on-demand GPU memory allocation ("memory growth") on every visible
# GPU so TensorFlow does not reserve the whole device up front.  This keeps
# the process-level memory measurements below meaningful and lets the
# benchmark coexist with other GPU users.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

# Verbose TensorLayer logging so each monitored iteration is printed.
tl.logging.set_verbosity(tl.logging.DEBUG)
| 14 | + |
# get the whole model
# Build VGG16 with TensorLayer's static-graph model definition (this script
# benchmarks the static mode; a sibling script presumably covers dynamic mode).
vgg = tl.models.vgg16(mode='static')

# system monitor
# Accumulators for process memory (RSS) and wall-clock time, sampled every
# `monitor_interval` iterations in the training loop below.
info = psutil.virtual_memory()
monitor_interval = MONITOR_INTERVAL
avg_mem_usage = 0    # running sum of sampled RSS values (averaged at the end)
max_mem_usage = 0    # peak sampled RSS
count = 0            # number of samples taken
total_time = 0       # total training wall-clock time in seconds

# training setting
num_iter = NUM_ITERS
batch_size = BATCH_SIZE
train_weights = vgg.trainable_weights
# NOTE: LERANING_RATE is spelled this way in exp_config; the name must match.
optimizer = tf.optimizers.Adam(learning_rate=LERANING_RATE)
loss_object = tl.cost.cross_entropy

# data generator
# Yields `num_iter` synthetic (input, label) batches of size `batch_size`.
gen = random_input_generator(num_iter, batch_size)
| 35 | + |
| 36 | + |
| 37 | +# training function |
# training function
def train_step(x_batch, y_batch):
    """Run one optimization step on a single batch and return its loss.

    Performs a forward pass through the module-level ``vgg`` model, computes
    cross-entropy via ``loss_object``, then applies gradients of
    ``train_weights`` with the module-level ``optimizer``.
    """
    # Record the forward pass so gradients can be taken afterwards.
    with tf.GradientTape() as tape:
        predictions = vgg(x_batch)
        batch_loss = loss_object(predictions, y_batch)

    # Backward pass + parameter update (outside the tape context).
    gradients = tape.gradient(batch_loss, train_weights)
    optimizer.apply_gradients(zip(gradients, train_weights))
    return batch_loss
| 49 | + |
| 50 | + |
# begin training
# Switch the model to training mode (enables dropout etc. in TensorLayer).
vgg.train()

for idx, data in enumerate(gen):
    # Time only the train step itself; data generation is excluded.
    start_time = time.time()

    loss = train_step(data[0], data[1])

    end_time = time.time()
    consume_time = end_time - start_time
    total_time += consume_time

    # Sample process memory every `monitor_interval` iterations
    # (idx == 0 is always sampled, so `count` is never zero below).
    if idx % monitor_interval == 0:
        # Resident set size of this process, in bytes.
        cur_usage = psutil.Process(os.getpid()).memory_info().rss
        max_mem_usage = max(cur_usage, max_mem_usage)
        avg_mem_usage += cur_usage
        count += 1
        tl.logging.info(
            "[*] {} iteration: memory usage {:.2f}MB, consume time {:.4f}s, loss {:.4f}".format(
                idx, cur_usage / (1024 * 1024), consume_time, loss
            )
        )

print('consumed time:', total_time)

# Convert accumulated byte counts to MB for the final report.
avg_mem_usage = avg_mem_usage / count / (1024 * 1024)
max_mem_usage = max_mem_usage / (1024 * 1024)
print('average memory usage: {:.2f}MB'.format(avg_mem_usage))
print('maximum memory usage: {:.2f}MB'.format(max_mem_usage))
0 commit comments