Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions bigcodebench/provider/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bigcodebench.gen.util.openai_request import make_auto_request
from bigcodebench.provider.utility import make_raw_chat_prompt
from bigcodebench.provider.base import DecoderBase
from bigcodebench.provider.utility import concurrent_call
from bigcodebench.provider.utility import concurrent_call, concurrent_map

class OpenAIChatDecoder(DecoderBase):
def __init__(self, name: str, base_url=None, reasoning_effort="medium", **kwargs) -> None:
Expand Down Expand Up @@ -38,8 +38,8 @@ def _codegen_api_batch(self, messages: List[str], num_samples: int) -> List[str]
api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=self.base_url
)

all_outputs = []
for message in tqdm(messages):
# Helper function to process a single message with its index
def process_message(index, message):
ret = make_auto_request(
client,
message=message,
Expand All @@ -52,7 +52,11 @@ def _codegen_api_batch(self, messages: List[str], num_samples: int) -> List[str]
outputs = []
for item in ret.choices:
outputs.append(item.message.content)
all_outputs.append(outputs)
return outputs

# Process messages in parallel using utility function
all_outputs = concurrent_map(messages, process_message)

return all_outputs

def _codegen_batch_via_concurrency(self, messages: List[str], num_samples: int) -> List[str]:
Expand Down
32 changes: 30 additions & 2 deletions bigcodebench/provider/utility.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List
from transformers import AutoTokenizer
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

EOS = [
"<|endoftext|>",
Expand Down Expand Up @@ -80,4 +81,31 @@ def make_raw_chat_prompt(
def concurrent_call(n, callback, /, *args, **kwargs):
    """Invoke *callback* concurrently *n* times with the same arguments.

    Returns the *n* results in call order; the first callback exception,
    if any, is re-raised when its result is collected.
    """
    def _invoke(_):
        return callback(*args, **kwargs)

    with ThreadPoolExecutor(max_workers=n) as pool:
        return list(pool.map(_invoke, range(n)))


# Parallel processing utility for processing multiple different items concurrently.
# Unlike concurrent_call which runs the same operation n times, concurrent_map
# processes n different items in parallel (similar to map() but concurrent).
# Automatically tracks indices to preserve output order and shows progress with tqdm.
def concurrent_map(items, callback, /, *args, **kwargs):
    """
    Process a collection of items in parallel using a callback function.

    Args:
        items: Iterable of items to process. Materialized once up front,
            so generators and other one-shot iterables are accepted in
            addition to lists.
        callback: Function that takes (index, item, *args, **kwargs) and
            returns a result
        *args, **kwargs: Additional arguments passed to callback

    Returns:
        List of results in the same order as the input items

    Raises:
        Whatever exception a failing callback raised, re-raised when its
        future's result is collected.
    """
    # Materialize first: the original implementation called len(items) and
    # enumerate(items) separately, which raises TypeError on a generator
    # (no len) and would otherwise consume a one-shot iterator twice.
    items = list(items)
    with ThreadPoolExecutor() as executor:
        # Map each future back to its input index so output order can be
        # restored even though completion order is arbitrary.
        futures = {executor.submit(callback, i, item, *args, **kwargs): i
                   for i, item in enumerate(items)}
        results = [None] * len(items)
        # Collect results as they complete with progress bar
        for future in tqdm(as_completed(futures), total=len(items)):
            index = futures[future]
            results[index] = future.result()
    return results