# train_grpo.py
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
# from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
import os
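
# Pin training to a single GPU and silence Weights & Biases console output.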
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.environ["WANDB_SILENT"] = "true"

# Load and prep dataset
SYSTEM_PROMPT ="""
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""defextract_xml_answer(text:str)->str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]return answer.strip()defextract_hash_answer(text:str)->str|None:if"####"notin text:returnNonereturn text.split("####")[1].strip()# uncomment middle messages for 1-shot promptingdefget_gsm8k_questions(split ="train")-> Dataset:# data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = load_dataset('./openai/gsm8k')[split]  # type: ignore
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    """
{
'question': 'Carolyn counts 6 monkeys and 6 birds in the tree outside her window. \
Then two of the monkeys each eat one bird. What percent of the animals outside her window are monkeys now?', \
'answer': '60', 'prompt': [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n', 'role': 'system'}, \
{'content': 'Carolyn counts 6 monkeys and 6 birds in the tree outside her window. Then two of the monkeys each eat one bird. \
What percent of the animals outside her window are monkeys now?', 'role': 'user'}]
}
"""return data # type: ignore
dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    print('-' * 50)
    print(prompts)
    print('-' * 50)
    print(completions)
    print('-' * 50)
    print(answer)
    print('-' * 50)
    """
[[{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n', 'role': 'system'}, \
{'content': 'Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. \
Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade \
Ahmed needs to get to beat Emily if all grades are whole numbers?', 'role': 'user'}], [{'content':......}]] # 16 identical prompts, one per generation
completions: model outputs, in the format [[{}],[{}],...]
[[{'role': 'assistant', 'content': "To determine the minimum grade Ahmed needs to get to beat Emily,
we first need to calculate the total grade the class will have for all assignments. The class has 9 assignments, and Emily has already achieved a 92 in all assignments.
Therefore, the total grade for Emily is already 92.\n\nSince Emily's final assignment is worth the same amount as all the other assignments, Emily will score another 90 on that assignment.
Let's denote Ahmed's score on the final assignment as \\(x\\). The total grade for Emily will then be:\n\n\\[ x + 90 = 92 + 90 = 182 \\]\n\nLet \\(y\\) be Ahmed's score on the final assignment.
The total grade for Ahmed will be:\n\n\\[ y + x + 91 = 91 + y + 91 = 92 + y \\]\n\nTo beat Emily, Ahmed's score must be greater than 92. Therefore, we need:\n\n\\[ 92 + y > 182 \\]\n\nSolving
this inequality for \\(y\\):\n\n\\[ y > 182 - 92 \\]\n\\[ y > 90 \\]\n\nSince all grades are whole numbers, the smallest whole number greater"}],...] # 16 completions generated by the model
answer: ['100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100', '100']
"""
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
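    # Note: without re.DOTALL, '.' does not match newlines, so this strict pattern only
    # matches completions whose <reasoning> and <answer> bodies are single lines.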
pattern =r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses =[completion[0]["content"]for completion in completions]
matches =[re.match(pattern, r)for r in responses]return[0.5ifmatchelse0.0formatchin matches]defsoft_format_reward_func(completions,**kwargs)->list[float]:"""Reward function that checks if the completion has a specific format."""
pattern =r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses =[completion[0]["content"]for completion in completions]
matches =[re.match(pattern, r)for r in responses]return[0.5ifmatchelse0.0formatchin matches]defcount_xml(text)->float:
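    # Partial-credit formatting score: 0.125 for each correctly placed tag, with a small
    # penalty for any text trailing the closing </answer> tag.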
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

# model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name ="./models/Qwen/Qwen2.5-0.5B-Instruct"if"Llama"in model_name:
output_dir ="outputs/Llama-1B-GRPO"
run_name ="Llama-1B-GRPO-gsm8k"else:
output_dir="outputs/Qwen-0.5B-GRPO"
run_name="Qwen-0.5B-GRPO-gsm8k"
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=256,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
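
# Load the policy in bfloat16; attn_implementation=None keeps the default attention
# backend (swap in "flash_attention_2" if FlashAttention is installed).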
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation=None,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
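# Reuse EOS as the padding token in case the tokenizer does not define one.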
tokenizer.pad_token = tokenizer.eos_token
# use peft at your own risk; not working for me with multi-GPU training
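# GRPOTrainer scores each completion with every reward function below and sums the
# results (equally weighted by default) into the scalar reward used for the GRPO update.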
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
    # peft_config=peft_config,
)
trainer.train()
trainer.save_model(output_dir)