diff --git a/bigcodebench/provider/openai.py b/bigcodebench/provider/openai.py
index ff1459f..caf52a6 100644
--- a/bigcodebench/provider/openai.py
+++ b/bigcodebench/provider/openai.py
@@ -6,7 +6,7 @@
 from bigcodebench.gen.util.openai_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
 from bigcodebench.provider.base import DecoderBase
-from bigcodebench.provider.utility import concurrent_call
+from bigcodebench.provider.utility import concurrent_call, concurrent_map
 
 class OpenAIChatDecoder(DecoderBase):
     def __init__(self, name: str, base_url=None, reasoning_effort="medium", **kwargs) -> None:
@@ -38,8 +38,8 @@ def _codegen_api_batch(self, messages: List[str], num_samples: int) -> List[str]
             api_key=os.getenv("OPENAI_API_KEY", "none"),
             base_url=self.base_url
         )
-        all_outputs = []
-        for message in tqdm(messages):
+        # Helper function to process a single message with its index
+        def process_message(index, message):
             ret = make_auto_request(
                 client,
                 message=message,
@@ -52,7 +52,11 @@
             outputs = []
             for item in ret.choices:
                 outputs.append(item.message.content)
-            all_outputs.append(outputs)
+            return outputs
+
+        # Process messages in parallel using utility function
+        all_outputs = concurrent_map(messages, process_message)
+
         return all_outputs
 
     def _codegen_batch_via_concurrency(self, messages: List[str], num_samples: int) -> List[str]:
diff --git a/bigcodebench/provider/utility.py b/bigcodebench/provider/utility.py
index ad09159..cc4b2ce 100644
--- a/bigcodebench/provider/utility.py
+++ b/bigcodebench/provider/utility.py
@@ -1,6 +1,7 @@
 from typing import List
 from transformers import AutoTokenizer
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
 
 EOS = [
     "<|endoftext|>",
@@ -80,4 +81,31 @@
 def concurrent_call(n, callback, /, *args, **kwargs):
     with ThreadPoolExecutor(max_workers=n) as executor:
         futures = [executor.submit(callback, *args, **kwargs) for _ in range(n)]
-        return [future.result() for future in futures]
\ No newline at end of file
+        return [future.result() for future in futures]
+
+
+# Parallel processing utility for processing multiple different items concurrently.
+# Unlike concurrent_call which runs the same operation n times, concurrent_map
+# processes n different items in parallel (similar to map() but concurrent).
+# Automatically tracks indices to preserve output order and shows progress with tqdm.
+def concurrent_map(items, callback, /, *args, **kwargs):
+    """
+    Process a list of items in parallel using a callback function.
+
+    Args:
+        items: List of items to process
+        callback: Function that takes (index, item, *args, **kwargs) and returns a result
+        *args, **kwargs: Additional arguments passed to callback
+
+    Returns:
+        List of results in the same order as input items
+    """
+    with ThreadPoolExecutor() as executor:
+        futures = {executor.submit(callback, i, item, *args, **kwargs): i
+                   for i, item in enumerate(items)}
+        results = [None] * len(items)
+        # Collect results as they complete with progress bar
+        for future in tqdm(as_completed(futures), total=len(items)):
+            index = futures[future]
+            results[index] = future.result()
+        return results
\ No newline at end of file