⚡Tutorial: How to Fine-tune gpt-oss
Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.
In this guide with screenshots, you'll learn to fine-tune your own custom gpt-oss model either locally on your machine or for free using Google Colab. We'll walk you through the entire process, from setup to running and saving your trained model.
Quickstart: Fine-tune gpt-oss-20b for free with our Colab notebook.
Compared to all other FA2 implementations, Unsloth gpt-oss fine-tuning achieves 1.5× faster training, 70% less VRAM use, and 10× longer context lengths, with no loss in accuracy.
QLoRA requirements: gpt-oss-20b = 14GB VRAM • gpt-oss-120b = 65GB VRAM.
BF16 LoRA requirements: gpt-oss-20b = 44GB VRAM • gpt-oss-120b = 210GB VRAM.
We've updated the gpt-oss fine-tuning and inference notebooks, making them much more stable.
🌐 Colab gpt-oss Fine-tuning
This section covers fine-tuning gpt-oss using our Google Colab notebooks. You can also download the gpt-oss notebook, open it in your favorite code editor, and follow our local gpt-oss guide.
Configuring gpt-oss and Reasoning Effort
We’ll load gpt-oss-20b using Unsloth's linearized version (as no other version will work). Configure the following parameters:
max_seq_length = 1024: recommended for quick testing and initial experiments.
load_in_4bit = True: use False for LoRA training (note: setting this to False will need at least 43GB VRAM). You MUST also set model_name = "unsloth/gpt-oss-20b-BF16".
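For reference, the loading cell looks like the sketch below; it mirrors the snippet shown in the local guide later on this page, so adjust the parameters above as needed rather than treating it as fixed:

from unsloth import FastLanguageModel
import torch

max_seq_length = 1024
dtype = None # None for auto detection

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b", # Unsloth's linearized version
    dtype = dtype,
    max_seq_length = max_seq_length,
    load_in_4bit = True, # 4-bit quantization to reduce memory
    full_finetuning = False,
    # token = "hf_...", # use one if using gated models
)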

You should see output similar to the example below. Note: we explicitly change the dtype to float32 to ensure correct training behavior.

Fine-tuning Hyperparameters (LoRA)
Now it's time to adjust your training hyperparameters. For a deeper dive into how, when, and what to tune, check out our detailed hyperparameters guide.
This step adds LoRA adapters for parameter-efficient fine-tuning. Only about 1% of the model’s parameters are trained, which makes the process significantly more efficient.
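The adapter setup below mirrors the get_peft_model call from the local guide later on this page; the rank r and lora_alpha values are the notebook defaults, not hard requirements:

model = FastLanguageModel.get_peft_model(
    model,
    r = 8, # LoRA rank; 8, 16, 32, 64, 128 are common choices
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none", # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # saves VRAM for long context
    random_state = 3407,
)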

Data Preparation
For this example, we will use the HuggingFaceH4/Multilingual-Thinking dataset. It contains chain-of-thought reasoning examples derived from user questions translated from English into four additional languages.
This is the same dataset referenced in OpenAI's fine-tuning cookbook.
The goal of using a multilingual dataset is to help the model learn and generalize reasoning patterns across multiple languages.
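Loading it is a one-liner; this is the same call used in the local guide later on this page:

from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train")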

gpt-oss introduces a reasoning effort system that controls how much reasoning the model performs. By default, the reasoning effort is set to low, but you can change it by setting the reasoning_effort parameter to low, medium, or high.
Example:
tokenizer.apply_chat_template(
text,
tokenize = False,
add_generation_prompt = False,
reasoning_effort = "medium",
)
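The dataset mapping below relies on a small formatting helper; this is the same formatting_prompts_func defined in the local guide later on this page:

def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }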
To format the dataset, we apply a customized version of the gpt-oss prompt:
from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)
Let's inspect the dataset by printing the first example:
print(dataset[0]['text'])

One unique feature of gpt-oss is its use of the OpenAI Harmony format, which supports structured conversations, reasoning output, and tool calling. This format includes tags such as <|start|>, <|message|>, and <|return|>.
Feel free to adapt the prompt and structure to suit your own dataset or use-case. For more guidance, refer to our dataset guide.
Train the model
We've pre-selected training hyperparameters for optimal results. However, you can modify them based on your specific use case. Refer to our hyperparameters guide.
In this example, we train for 60 steps to speed up the process. For a full training run, set num_train_epochs = 1 and disable the step limit by setting max_steps = None.
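A minimal trainer setup, based on the config from the local guide later on this page (here with max_steps = 60 to match the 60-step run described above; adjust freely for your use case):

from trl import SFTConfig, SFTTrainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 60,
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)
trainer_stats = trainer.train()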

During training, monitor the loss to ensure that it is decreasing over time. This confirms that the training process is functioning correctly.

Inference: Run your trained model
Now it's time to run inference with your fine-tuned model. You can modify the instruction and input, but leave the output blank.
In this example, we test the model's ability to reason in French by adding a specific instruction to the system prompt, following the same structure used in our dataset.
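The inference cell follows the same pattern as the local guide later on this page; the French math prompt is only an example you can swap out:

messages = [
    {"role": "system", "content": "reasoning language: French\n\nYou are a helpful assistant that can solve mathematical problems."},
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "medium",
).to(model.device)

from transformers import TextStreamer
_ = model.generate(**inputs, max_new_tokens = 2048, streamer = TextStreamer(tokenizer))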

This should produce an output similar to:

Save/export your model
To save your fine-tuned model, it must be exported in the Safetensors format.
Remember: saving or merging QLoRA fine-tuned models to GGUF is not yet supported, since QLoRA fine-tuned gpt-oss models currently only work in Unsloth.
To use your fine-tuned gpt-oss models in other frameworks (e.g. Hugging Face, llama.cpp with GGUF), you must train with LoRA on our BF16 model (requires >43GB VRAM). This produces a BF16 fine-tuned model that can be exported and converted as needed. You can use our llama.cpp GGUF scripts to convert the model to GGUF.
To save your LoRA adapters locally and optionally push them to the Hugging Face Hub, follow these steps:
model.save_pretrained("finetuned_model)
tokenizer.save_pretrained("finetuned_model")
model.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
tokenizer.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF

🖥️ Local gpt-oss Fine-tuning
This chapter covers fine-tuning gpt-oss on your local device. While gpt-oss-20b fine-tuning can operate on just 14GB VRAM, we recommend having at least 16GB VRAM available to ensure stable and reliable training runs.
Install Unsloth Locally
Ensure your device is Unsloth compatible; you can read our detailed installation guide for more details.
Note that pip install unsloth will not work for this setup, as we need the latest PyTorch, Triton, and related packages. Install Unsloth using this specific command:
# We're installing the latest Torch, Triton, OpenAI's Triton kernels, Transformers and Unsloth!
!pip install --upgrade -qqq uv
try: import numpy; install_numpy = f"numpy=={numpy.__version__}"
except: install_numpy = "numpy"
!uv pip install -qqq \
"torch>=2.8.0" "triton>=3.4.0" {install_numpy} \
"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
"unsloth[base] @ git+https://github.com/unslothai/unsloth" \
torchvision bitsandbytes \
git+https://github.com/huggingface/transformers \
git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
Configuring gpt-oss and Reasoning Effort
We’ll load gpt-oss-20b using Unsloth's linearized version (as no other version will work for QLoRA fine-tuning). Configure the following parameters:
max_seq_length = 1024: recommended for quick testing and initial experiments.
load_in_4bit = True: use False for LoRA training (note: setting this to False will need at least 43GB VRAM). You MUST also set model_name = "unsloth/gpt-oss-20b-BF16".
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024
dtype = None
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
"unsloth/gpt-oss-20b-unsloth-bnb-4bit", # 20B model using bitsandbytes 4bit quantization
"unsloth/gpt-oss-120b-unsloth-bnb-4bit",
"unsloth/gpt-oss-20b", # 20B model using MXFP4 format
"unsloth/gpt-oss-120b",
] # More models at https://huggingface.co/unsloth
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/gpt-oss-20b",
dtype = dtype, # None for auto detection
max_seq_length = max_seq_length, # Choose any for long context!
load_in_4bit = True, # 4 bit quantization to reduce memory
full_finetuning = False, # [NEW!] We have full finetuning now!
# token = "hf_...", # use one if using gated models
)
You should see output similar to the example below. Note: we explicitly change the dtype to float32 to ensure correct training behavior.
Fine-tuning Hyperparameters (LoRA)
Now it's time to adjust your training hyperparameters. For a deeper dive into how, when, and what to tune, check out our detailed hyperparameters guide.
This step adds LoRA adapters for parameter-efficient fine-tuning. Only about 1% of the model’s parameters are trained, which makes the process significantly more efficient.
model = FastLanguageModel.get_peft_model(
model,
r = 8, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
Data Preparation
For this example, we will use the HuggingFaceH4/Multilingual-Thinking dataset. It contains chain-of-thought reasoning examples derived from user questions translated from English into four additional languages.
This is the same dataset referenced in OpenAI's fine-tuning cookbook. The goal of using a multilingual dataset is to help the model learn and generalize reasoning patterns across multiple languages.
def formatting_prompts_func(examples):
convos = examples["messages"]
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
return { "text" : texts, }
pass
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
dataset
gpt-oss introduces a reasoning effort system that controls how much reasoning the model performs. By default, the reasoning effort is set to low, but you can change it by setting the reasoning_effort parameter to low, medium, or high.
Example:
tokenizer.apply_chat_template(
text,
tokenize = False,
add_generation_prompt = False,
reasoning_effort = "medium",
)
To format the dataset, we apply a customized version of the gpt-oss prompt:
from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)
Let's inspect the dataset by printing the first example:
print(dataset[0]['text'])

One unique feature of gpt-oss is its use of the OpenAI Harmony format, which supports structured conversations, reasoning output, and tool calling. This format includes tags such as <|start|>, <|message|>, and <|return|>.
Feel free to adapt the prompt and structure to suit your own dataset or use-case. For more guidance, refer to our dataset guide.
Train the model
We've pre-selected training hyperparameters for optimal results. However, you can modify them based on your specific use case. Refer to our hyperparameters guide.
In this example, we limit training to max_steps = 30 to speed up the process. For a full training run, set num_train_epochs = 1 and disable the step limit by setting max_steps = None.
from trl import SFTConfig, SFTTrainer
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
args = SFTConfig(
per_device_train_batch_size = 1,
gradient_accumulation_steps = 4,
warmup_steps = 5,
# num_train_epochs = 1, # Set this for 1 full training run.
max_steps = 30,
learning_rate = 2e-4,
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
report_to = "none", # Use this for WandB etc
),
)
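Then kick off training; trainer.train() returns the training stats, which you can inspect afterwards:

trainer_stats = trainer.train()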
During training, monitor the loss to ensure that it is decreasing over time. This confirms that the training process is functioning correctly.

Inference: Run Your Trained Model
Now it's time to run inference with your fine-tuned model. You can modify the instruction and input, but leave the output blank.
In this example, we test the model's ability to reason in French by adding a specific instruction to the system prompt, following the same structure used in our dataset.
messages = [
{"role": "system", "content": "reasoning language: French\n\nYou are a helpful assistant that can solve mathematical problems."},
{"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt = True,
return_tensors = "pt",
return_dict = True,
reasoning_effort = "medium",
).to(model.device)
from transformers import TextStreamer
_ = model.generate(**inputs, max_new_tokens = 2048, streamer = TextStreamer(tokenizer))
This should produce an output similar to:

Save and Export Your Model
To save your fine-tuned model, it must be exported in the Safetensors format.
Remember: saving or merging QLoRA fine-tuned models to GGUF is not yet supported, since QLoRA fine-tuned gpt-oss models currently only work in Unsloth.
To use your fine-tuned gpt-oss models in other frameworks (e.g. Hugging Face, llama.cpp with GGUF), you must train with LoRA on our BF16 model (requires >43GB VRAM). This produces a BF16 fine-tuned model that can be exported and converted as needed. You can use our llama.cpp GGUF scripts to convert the model to GGUF.
To save your LoRA adapters locally and optionally push them to the Hugging Face Hub, follow these steps:
model.save_pretrained("finetuned_model)
tokenizer.save_pretrained("finetuned_model")
model.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
tokenizer.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
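If you trained with LoRA on the BF16 model, a rough export path is to merge the adapters into 16-bit weights and then run llama.cpp's converter. This is only a sketch: the merged folder name is a placeholder, and you should check our llama.cpp GGUF scripts for the exact conversion steps for gpt-oss.

# Merge the LoRA adapters into the BF16 base weights (Unsloth helper)
model.save_pretrained_merged("finetuned_model_merged", tokenizer, save_method = "merged_16bit")

# Then convert with llama.cpp (run in a shell, from the llama.cpp repo):
# python convert_hf_to_gguf.py finetuned_model_merged --outfile finetuned_model.gguf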
🏁 And that's it!
You've fine-tuned gpt-oss with Unsloth. We're currently working on RL and GRPO implementations, as well as improved model saving and running, so stay tuned.
As always, feel free to drop by our Discord or Reddit if you need any help.
❓FAQ (Frequently Asked Questions)
1. Can I export my model to use in Hugging Face, llama.cpp GGUF or vLLM later?
Yes, but only if you do LoRA fine-tuning using our BF16 weights. This means you must set model_name = "unsloth/gpt-oss-20b-BF16".
Keep in mind VRAM usage will be about 4× higher, so gpt-oss-20b will require around 45GB of VRAM.
We are working on better exporting for the model, so stay tuned.
2. Can I do fp4 or MXFP4 training with gpt-oss?
No, currently no framework supports FP4 or MXFP4 training. Unsloth, however, is the only framework to support QLoRA 4-bit fine-tuning for the model, which reduces VRAM use by more than 4×.
3. Can I export my model to MXFP4 format after training?
No, currently no library or framework supports this.
4. Can I do Reinforcement Learning (RL) or GRPO with gpt-oss?
No, currently no library or framework supports RL for gpt-oss. We are working on it, but it will be hard to fit on Colab considering how large the model is.
Acknowledgements: A huge thank you to Eyera for contributing to this guide!