Modifying Large Language Model Post-Training for Diverse Creative Writing
TL;DR Highlight
A technique that up-weights rare, high-quality samples during DPO/ORPO training, boosting output diversity to human level while maintaining GPT-4o-level quality
Who Should Read
ML engineers fine-tuning LLMs for creative writing, story generation, and content creation who are struggling with overly uniform outputs — especially teams running DPO/ORPO-based training pipelines.
Core Mechanics
- Quantitatively confirmed that state-of-the-art instruction-tuned models like GPT-4o, Claude-3.5-Sonnet, and DeepSeek-R1 achieve high quality but exhibit significantly lower output diversity than human-written data
- Core idea: a 'deviation' score — representing how different a response is from other responses to the same prompt — is used as a weight in the DPO/ORPO loss function
- Responses that are rare but high-quality receive higher weights during training, guiding the model away from converging on 'common patterns'
- A Llama-3.1-8B-based DDPO-both model achieves diversity levels comparable to human-written data (Gold) on both semantic and style diversity, while matching GPT-4o on quality
- When the number of responses per prompt is too low (4 or fewer), quality degrades — but this can be addressed by setting a minimum deviation threshold or using only high-quality responses
- Outperforms DivPO (an existing diversity-enhancing technique) in both diversity and data efficiency — DivPO wastes data through filtering, whereas this method utilizes all data
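The mitigation for small response sets mentioned above can be sketched as a simple floor on the normalized deviation weights, so that no sample's gradient contribution vanishes. This is a minimal illustration; the floor value of 0.5 is an assumption, not a number from the paper.

```python
import numpy as np

def floor_deviations(deviations: np.ndarray, floor: float = 0.5) -> np.ndarray:
    """Clamp normalized deviation weights from below so no sample is
    effectively dropped when few responses exist per prompt.
    Note: after clamping, the weights may no longer sum exactly to n."""
    return np.maximum(deviations, floor)

# With only 4 responses, extreme weights get moderated
weights = floor_deviations(np.array([0.0, 0.2, 1.4, 2.4]))
print(weights)  # [0.5 0.5 1.4 2.4]
```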
Evidence
- Human evaluation of DDPO-both vs. GPT-4o: diversity win rate 100% vs. 0% (p < 0.001), quality win rate 68% vs. 24% (p < 0.001)
- DDPO-both vs. DPO: diversity win rate 62% vs. 26% (p < 0.001), no statistically significant difference in quality (p > 0.1)
- Llama-3.1-8B DDPO-both reaches nearly the same semantic diversity as human-written Gold data, while maintaining a reddit-reward slightly below GPT-4o-iter
- With 6 or more responses per prompt, DDPO-both achieves higher diversity and equivalent quality compared to DPO/DivPO; quality degradation only occurs with 4 or fewer responses
How to Apply
- In an existing DPO training pipeline, embed multiple responses to the same prompt (e.g., using jinaai/jina-embeddings-v3), compute each response's average cosine distance to the other responses as its deviation, and use it as a per-sample weight on the loss. The code change amounts to a single-line modification to the loss function.
- When building a creative writing dataset, it's important to collect at least 6–8 responses per prompt. UGC (user-generated content) platforms like r/writingPrompts are well-suited to this requirement.
- To capture both semantic diversity and style diversity, combine the deviations from two embedding models using a geometric mean, as in DDPO-both. If only one type of diversity is needed, use only the corresponding embedding model.
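The DDPO-both combination step can be sketched as follows, assuming the two per-response deviation arrays have already been computed with a semantic and a style embedding model (the toy values below are illustrative):

```python
import numpy as np

def combine_deviations(dev_semantic: np.ndarray, dev_style: np.ndarray) -> np.ndarray:
    """Combine semantic and style deviations via the geometric mean,
    then rescale so the weights again sum to the number of responses."""
    combined = np.sqrt(dev_semantic * dev_style)
    return combined / combined.sum() * len(combined)

# Toy example: 4 responses to one prompt
dev_sem = np.array([1.5, 0.5, 1.0, 1.0])
dev_sty = np.array([2.0, 0.25, 1.0, 0.75])
weights = combine_deviations(dev_sem, dev_sty)
```

The geometric mean keeps a response's weight high only when it is unusual under both embedders, which is what lets DDPO-both target semantic and style diversity simultaneously.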
Code Example
```python
# DDPO core idea: weight the DPO loss by each sample's deviation
import numpy as np
import torch
from sentence_transformers import SentenceTransformer


def compute_deviation(responses: list[str], embedder) -> np.ndarray:
    """
    Compute the deviation of each response to the same prompt:
    deviation = average cosine distance from all other responses.
    """
    embeddings = embedder.encode(responses, normalize_embeddings=True)
    n = len(embeddings)
    deviations = []
    for i in range(n):
        dists = [1 - np.dot(embeddings[i], embeddings[j])
                 for j in range(n) if i != j]
        deviations.append(np.mean(dists))
    deviations = np.array(deviations)
    # Min-max normalize to [0, 1], then rescale so the weights sum to n
    d_min, d_max = deviations.min(), deviations.max()
    if d_max > d_min:
        deviations = (deviations - d_min) / (d_max - d_min)
    else:
        deviations = np.full(n, 0.5)
    return deviations / deviations.sum() * n


def ddpo_loss(policy_logps_w, policy_logps_l, ref_logps_w, ref_logps_l,
              deviation_w, beta=0.1):
    """
    Diversified DPO loss.
    deviation_w: deviation weight of the winning response (higher when rarer).
    """
    log_ratios_w = policy_logps_w - ref_logps_w
    log_ratios_l = policy_logps_l - ref_logps_l
    logits = beta * (log_ratios_w - log_ratios_l)
    # Weight the per-sample loss by deviation (the key change vs. vanilla DPO)
    loss = -deviation_w * torch.nn.functional.logsigmoid(logits)
    return loss.mean()


# Usage example
embedder = SentenceTransformer('jinaai/jina-embeddings-v3', trust_remote_code=True)
responses = ["story A...", "story B...", "story C...", "story D..."]
deviations = compute_deviation(responses, embedder)
print(f"Deviations: {deviations}")  # Higher values indicate rarer styles
```
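The same deviation weighting applies to ORPO. Below is a hedged sketch of a deviation-weighted ORPO loss; the log-odds term and the lam weighting follow the standard reference-free ORPO objective, and the paper's exact formulation may differ in detail.

```python
import torch
import torch.nn.functional as F

def dorpo_loss(policy_logps_w, policy_logps_l, deviation_w, lam=0.1):
    """Sketch: ORPO's odds-ratio term weighted by the winning response's
    deviation, mirroring ddpo_loss. Inputs are per-sequence (average)
    log-probabilities, assumed strictly negative."""
    def log_odds(logps):
        # log-odds(p) = log(p) - log(1 - p), computed stably in log space
        return logps - torch.log1p(-torch.exp(logps))
    log_or = log_odds(policy_logps_w) - log_odds(policy_logps_l)
    or_term = -deviation_w * F.logsigmoid(log_or)   # weighted preference term
    sft_term = -policy_logps_w                      # NLL on the winner
    return (sft_term + lam * or_term).mean()

# Toy call with illustrative log-probabilities and a deviation weight
loss = dorpo_loss(torch.tensor([-1.0]), torch.tensor([-2.0]), torch.tensor([1.2]))
```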
Original Abstract
As creative writing tasks do not have singular correct answers, large language models (LLMs) trained to perform these tasks should be able to generate diverse valid outputs. However, LLM post-training often focuses on improving generation quality but neglects to facilitate output diversity. Hence, in creative writing generation, we investigate post-training approaches to promote both output diversity and quality. Our core idea is to include deviation -- the degree of difference between a training sample and all other samples with the same prompt -- in the training objective to facilitate learning from rare high-quality instances. By adopting our approach to direct preference optimization (DPO) and odds ratio preference optimization (ORPO), we demonstrate that we can promote the output diversity of trained models while minimally decreasing quality. Our best model with 8B parameters could achieve on-par diversity as a human-created dataset while having output quality similar to the best instruction-tuned models we examined, GPT-4o and DeepSeek-R1. We further validate our approaches with a human evaluation, an ablation, and a comparison to an existing diversification approach, DivPO.