"""
Conversation generator for WizardSData.
"""
import os
import json
import uuid
from typing import Dict, Any, List, Optional
import openai
from .config import config
from .templates import render_prompt_from_path
def initialize_apis(api_key: Optional[str] = None) -> tuple:
    """
    Create the OpenAI API clients used for the two conversation roles.

    Args:
        api_key: Optional API key to use. If None, falls back to the
            'API_KEY' value from the global configuration.

    Returns:
        Tuple with (client_api, advisor_api).

    Raises:
        ValueError: If no API key is available from either source.
    """
    resolved_key = api_key if api_key else config.get('API_KEY')
    if not resolved_key:
        raise ValueError("API_KEY is required but not provided")
    # One independent client per role, both authenticated with the same key.
    return openai.Client(api_key=resolved_key), openai.Client(api_key=resolved_key)
def get_model_response(api_client, model: str, messages: List[Dict[str, str]],
                       temperature: float, top_p: float, frequency_penalty: float,
                       max_tokens: Optional[int]) -> str:
    """
    Request a chat completion and return its text content.

    Args:
        api_client: OpenAI client instance.
        model: Model name to use.
        messages: List of message dictionaries (role and content).
        temperature: Temperature setting for response randomness.
        top_p: Top p setting for response diversity.
        frequency_penalty: Frequency penalty to apply.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The generated response text, stripped of surrounding whitespace.
    """
    # Sampling parameters are forwarded only when explicitly set, so the
    # API falls back to its own defaults for any that are None.
    optional = {
        'temperature': temperature,
        'top_p': top_p,
        'frequency_penalty': frequency_penalty,
        'max_tokens': max_tokens,
    }
    params = {'model': model, 'messages': messages}
    params.update({name: value for name, value in optional.items() if value is not None})
    response = api_client.chat.completions.create(**params)
    return response.choices[0].message.content.strip()
def save_conversation(conversations: List[Dict[str, Any]], file_path: str) -> bool:
    """
    Save the conversation dataset to a JSON file.

    Args:
        conversations: List of conversation dictionaries.
        file_path: Path to save the conversations.

    Returns:
        True if successful, False otherwise.
    """
    try:
        # Create the parent directory if needed. os.path.dirname() returns
        # '' for a bare filename and os.makedirs('') raises, which previously
        # made saving to a plain filename always fail — guard against that.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # Explicit encoding so output is stable across platforms.
        with open(file_path, 'w', encoding='utf-8') as file:
            json.dump(conversations, file, indent=4)
        return True
    except Exception as e:
        print(f"Error saving conversations: {str(e)}")
        return False
def initiate_conversation(client_prompt: str, advisor_prompt: str, financial_goal: str,
                          client_api, advisor_api, max_questions: int) -> List[Dict[str, Any]]:
    """
    Initiate a continuous conversation between client and advisor models.

    The two models alternate turns: the client speaks first, the advisor
    replies, and either side can terminate the conversation by including the
    literal marker "[END]" in its response (the marker is stripped from the
    stored text).

    Args:
        client_prompt: System prompt for the client model.
        advisor_prompt: System prompt for the advisor model.
        financial_goal: Financial goal for this conversation.
        client_api: OpenAI client for the client model.
        advisor_api: OpenAI client for the advisor model.
        max_questions: Maximum number of interactions.

    Returns:
        List of dictionaries representing the conversation. Each entry holds
        one exchange: 'rol1' is the client turn and 'rol2' the advisor reply
        ('rol2' stays "" if the client ended the conversation on that turn).
    """
    # Get configuration parameters (sampling settings for each role).
    model_client = config.get('model_client')
    temperature_client = config.get('temperature_client')
    top_p_client = config.get('top_p_client')
    frequency_penalty_client = config.get('frequency_penalty_client')
    max_tokens_client = config.get('max_tokens_client')
    model_advisor = config.get('model_advisor')
    temperature_advisor = config.get('temperature_advisor')
    top_p_advisor = config.get('top_p_advisor')
    frequency_penalty_advisor = config.get('frequency_penalty_advisor')
    max_tokens_advisor = config.get('max_tokens_advisor')
    # Generate unique conversation ID shared by all turns of this dialogue.
    conversation_id = str(uuid.uuid4())
    conversation_dataset = []
    sequence = 0
    # Initialize conversations with system prompts. Each side keeps its own
    # message history: its partner's messages appear as role "user" and its
    # own as role "assistant".
    client_conversation = [{"role": "system", "content": client_prompt}]
    advisor_conversation = [{"role": "system", "content": advisor_prompt}]
    # The 2.1 factor gives slack beyond max_questions — presumably so the
    # models can close the dialogue naturally via "[END]" — TODO confirm.
    for _ in range(round(max_questions * 2.1)):
        # Client response
        client_response = get_model_response(
            client_api,
            model=model_client,
            messages=client_conversation,
            temperature=temperature_client,
            top_p=top_p_client,
            frequency_penalty=frequency_penalty_client,
            max_tokens=max_tokens_client
        )
        print("client: " + client_response)
        # Add the client response to the dataset BEFORE checking for [END]
        # so the final client turn is never lost.
        conversation_dataset.append({
            "id_conversation": conversation_id,
            "topic": financial_goal,
            "sequence": sequence,
            "rol1": client_response.replace("[END]", "").strip(),
            "rol2": ""  # Placeholder for advisor response
        })
        # Now check for [END] after adding to dataset
        if "[END]" in client_response:
            break
        client_conversation.append({"role": "assistant", "content": client_response})
        advisor_conversation.append({"role": "user", "content": client_response})
        # NOTE(review): sequence advances after BOTH the client and the
        # advisor turns, so stored 'sequence' values step by 2 per exchange
        # (0, 2, 4, ...) — confirm this is intended.
        sequence += 1
        # Advisor response
        advisor_response = get_model_response(
            advisor_api,
            model=model_advisor,
            messages=advisor_conversation,
            temperature=temperature_advisor,
            top_p=top_p_advisor,
            frequency_penalty=frequency_penalty_advisor,
            max_tokens=max_tokens_advisor
        )
        print("advisor: " + advisor_response)
        # Update the last entry in the dataset with the advisor response
        if conversation_dataset:
            conversation_dataset[-1]["rol2"] = advisor_response.replace("[END]", "").strip()
        # Now check for [END] after adding to dataset
        if "[END]" in advisor_response:
            break
        advisor_conversation.append({"role": "assistant", "content": advisor_response})
        client_conversation.append({"role": "user", "content": advisor_response})
        sequence += 1
    return conversation_dataset
def start_generation() -> bool:
    """
    Start generating conversations between roles based on the current configuration.

    This function orchestrates the entire conversation generation process. It:
    1. Validates the current configuration
    2. Initializes the API clients
    3. Loads conversation profiles from the configured file
    4. Renders prompts for each profile
    5. Generates conversations for each profile using the appropriate models
    6. Saves all conversations to the configured output file

    The function relies on the global configuration instance having all necessary
    parameters properly set.

    Returns
    -------
    bool
        True if the generation process completed successfully, False otherwise.
        Returns False in the following cases:
        - Invalid configuration (missing parameters)
        - No profiles found in the profile file
        - Failure to render prompts
        - Error during API calls
        - Failure to save conversations

    Notes
    -----
    Before calling this function, ensure all required configuration parameters are set:
    - API_KEY: Required for API access
    - template_client_prompt: Path to client prompt template
    - template_advisor_prompt: Path to advisor prompt template
    - file_profiles: Path to JSON file containing conversation profiles
    - file_output: Path where generated conversations will be saved
    - model_client: Model configuration for client responses
    - model_advisor: Model configuration for advisor responses

    The function will log progress and error information to standard output.

    Examples
    --------
    >>> from wizardsdata.config import set_config
    >>> from wizardsdata.generator import start_generation
    >>>
    >>> # Set up configuration
    >>> errors = set_config(
    ...     API_KEY="your-api-key",
    ...     template_client_prompt="templates/client.txt",
    ...     template_advisor_prompt="templates/advisor.txt",
    ...     file_profiles="data/profiles.json",
    ...     file_output="output/conversations.json",
    ...     model_client="gpt-4",
    ...     model_advisor="gpt-4"
    ... )
    >>>
    >>> if not errors:
    ...     # Start generation
    ...     success = start_generation()
    ...     if success:
    ...         print("Generation completed successfully!")
    ...     else:
    ...         print("Generation failed.")
    ... else:
    ...     print(f"Invalid configuration: {errors}")
    """
    # Validate configuration up front so we fail fast with a readable message.
    if not config.is_valid():
        missing = config.validate()
        print(f"Missing or invalid configuration parameters: {', '.join(missing)}")
        return False
    # Broad try/except is the top-level boundary for this whole pipeline:
    # any failure (I/O, JSON parsing, API errors) is reported and mapped
    # to a False return rather than propagating to the caller.
    try:
        # Initialize APIs
        client_api, advisor_api = initialize_apis()
        # Load profiles — expects a JSON object with a top-level
        # "profiles" list; anything else yields an empty list.
        with open(config.get('file_profiles'), 'r') as file:
            data = json.load(file)
        profiles = data.get('profiles', [])
        if not profiles:
            print("No profiles found in the profile file.")
            return False
        # Set up prompts for each profile. Profiles whose templates fail
        # to render are skipped (with a message) rather than aborting.
        prompts = []
        max_recommended_questions = config.get('max_recommended_questions')
        for profile in profiles:
            client_prompt = render_prompt_from_path(
                config.get('template_client_prompt'),
                profile,
                max_questions=max_recommended_questions
            )
            advisor_prompt = render_prompt_from_path(
                config.get('template_advisor_prompt'),
                profile
            )
            if not client_prompt or not advisor_prompt:
                print(f"Failed to render prompts for profile {profile.get('id', 'unknown')}")
                continue
            prompts.append({
                'profile_id': profile.get('id'),
                'client_prompt': client_prompt,
                'advisor_prompt': advisor_prompt,
                'financial_goal': profile.get('financial_goal', 'Unknown')
            })
        # Generate conversations — all profiles' turns are accumulated
        # into one flat list and written to a single output file.
        all_conversations = []
        for prompt in prompts:
            conversation = initiate_conversation(
                prompt['client_prompt'],
                prompt['advisor_prompt'],
                prompt['financial_goal'],
                client_api,
                advisor_api,
                max_recommended_questions
            )
            all_conversations.extend(conversation)
            print(f"Generated conversation for profile {prompt['profile_id']} with {len(conversation)} turns.")
        # Save conversations
        success = save_conversation(all_conversations, config.get('file_output'))
        if success:
            print(f"Successfully saved {len(all_conversations)} conversation turns to {config.get('file_output')}")
        else:
            print("Failed to save conversations.")
            return False
        return True
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        return False