|
| 1 | +/* |
| 2 | + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +/** |
| 18 | + * BeamSearch OpenNMT |
| 19 | + **/ |
| 20 | + |
| 21 | +#pragma once |
| 22 | + |
| 23 | +#include <cuda_runtime.h> |
| 24 | +#include "fastertransformer/allocator.h" |
| 25 | +#include "fastertransformer/cuda/cuda_kernels.h" |
| 26 | +#include "fastertransformer/cuda/open_attention.h" |
| 27 | +#include "fastertransformer/cuda/decoding_kernel_check.h" |
| 28 | + |
| 29 | +namespace fastertransformer |
| 30 | +{ |
| 31 | + |
| 32 | +template <typename T> |
| 33 | +void BeamSearch_OpenNMT( |
| 34 | + float *log_probs, float *cum_log_probs, bool *finished, |
| 35 | + T **key_cache, T **value_cache, |
| 36 | + int *parent_ids, |
| 37 | + int *sequence_length, |
| 38 | + int *word_ids, |
| 39 | + int *ids, |
| 40 | + int *output_ids, |
| 41 | + const int batch_size, const int beam_width, |
| 42 | + const int vocab_size, const int hidden_dim, const int step, |
| 43 | + const int cache_size, const int decoder_layers, cudaStream_t stream, |
| 44 | + const int end_id, |
| 45 | + int *finished_count) |
| 46 | +{ |
| 47 | +#ifdef NDEBUG |
| 48 | + /* adding cum_log_probs to log_probs */ |
| 49 | + broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); |
| 50 | +#else |
| 51 | + broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); |
| 52 | + cudaDeviceSynchronize(); |
| 53 | + check_cuda_error(cudaGetLastError()); |
| 54 | + |
| 55 | + /* |
| 56 | + User can check the broadcast_kernel by broadcast_kernel_check. |
| 57 | + broadcast_kernel_check will compare the results of GPU and CPU. |
| 58 | + Note that broadcast_kernel_check contains broadcast_kernelLauncher and uses do not need to call it again. |
| 59 | + */ |
| 60 | + // broadcast_kernel_check(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream); |
| 61 | +#endif |
| 62 | + |
| 63 | +#ifdef NDEBUG |
| 64 | + /*Use two round kernels to pick the topK values for each batch */ |
| 65 | + topK(log_probs, ids, batch_size, beam_width, vocab_size, stream); |
| 66 | +#else |
| 67 | + topK(log_probs, ids, batch_size, beam_width, vocab_size, stream); |
| 68 | + cudaDeviceSynchronize(); |
| 69 | + check_cuda_error(cudaGetLastError()); |
| 70 | + |
| 71 | + /* |
| 72 | + User can check the topK by topK_check. |
| 73 | + topK_check will compare the results of GPU and CPU. |
| 74 | + Note that topK_check contains topK and uses do not need to call it again. |
| 75 | + */ |
| 76 | + // topK_kernel_check(log_probs, ids, batch_size, beam_width, vocab_size, stream); |
| 77 | +#endif |
| 78 | + |
| 79 | +#ifdef NDEBUG |
| 80 | + update(log_probs, cum_log_probs, ids, finished, |
| 81 | + parent_ids, sequence_length, word_ids, output_ids, |
| 82 | + batch_size, beam_width, vocab_size, stream, |
| 83 | + end_id, finished_count); |
| 84 | +#else |
| 85 | + update(log_probs, cum_log_probs, ids, finished, |
| 86 | + parent_ids, sequence_length, word_ids, output_ids, |
| 87 | + batch_size, beam_width, vocab_size, stream, |
| 88 | + end_id, finished_count); |
| 89 | + cudaDeviceSynchronize(); |
| 90 | + check_cuda_error(cudaGetLastError()); |
| 91 | + |
| 92 | + /* |
| 93 | + User can check the update by update_kernel_check. |
| 94 | + update_kernel_check will compare the results of GPU and CPU. |
| 95 | + Note that update_kernel_check contains update and uses do not need to call it again. |
| 96 | + */ |
| 97 | + // update_kernel_check(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids, |
| 98 | + // batch_size, beam_width, vocab_size, stream, end_id, finished_count); |
| 99 | +#endif |
| 100 | + |
| 101 | +#ifdef NDEBUG |
| 102 | + update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size, |
| 103 | + beam_width, hidden_dim, step, cache_size, |
| 104 | + decoder_layers, stream); |
| 105 | +#else |
| 106 | + update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size, |
| 107 | + beam_width, hidden_dim, step, cache_size, |
| 108 | + decoder_layers, stream); |
| 109 | + cudaDeviceSynchronize(); |
| 110 | + check_cuda_error(cudaGetLastError()); |
| 111 | + |
| 112 | + /* |
| 113 | + User can check the update_KV_cache by update_KV_cache_kernel_check. |
| 114 | + update_KV_cache_kernel_check will compare the results of GPU and CPU. |
| 115 | + Note that update_KV_cache_kernel_check contains update_KV_cache and uses do not need to call it again. |
| 116 | + */ |
| 117 | + // update_KV_cache_kernel_check(key_cache, value_cache, parent_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream); |
| 118 | +#endif |
| 119 | +} |
| 120 | + |
| 121 | +} // namespace fastertransformer |
0 commit comments