diff --git a/test/sampling_benchmark.cpp b/test/sampling_benchmark.cpp index 9ba42b998..0c3d8e113 100644 --- a/test/sampling_benchmark.cpp +++ b/test/sampling_benchmark.cpp @@ -43,6 +43,7 @@ struct SamplingBenchmark { int num_iter = 1000; auto logits = params->p_device->Allocate(static_cast(config.model.vocab_size) * batch_size_); + auto test_start = std::chrono::high_resolution_clock::now(); for (int i = 0; i < num_iter; i++) { auto generator = Generators::CreateGenerator(*model, *params); @@ -57,6 +58,10 @@ struct SamplingBenchmark { auto stop = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(stop - start); total_time += duration.count(); + if (std::chrono::duration_cast(stop - start) > std::chrono::minutes(1)) { + std::cout << Generators::SGR::Bg_Red << " ABORTING " << Generators::SGR::Reset << " loop due to slow performance(took more than 1 minute) on iteration " << i << std::endl; + break; + } } double average_time = total_time / double(num_iter); std::cout << "Average time taken: " << average_time << " microseconds" << std::endl;