OpenAI gpt-oss & all model types now supported!

Tutorial: How to Fine-tune gpt-oss

Learn step-by-step how to train OpenAI gpt-oss locally with Unsloth.

In this guide with screenshots, you'll learn to fine-tune your own custom gpt-oss model either locally on your machine or for free using Google Colab. We'll walk you through the entire process, from setup to running and saving your trained model.

Quickstart: Fine-tune gpt-oss-20b for free with our Colab notebook.

Compared to all other FA2 implementations, Unsloth gpt-oss fine-tuning achieves 1.5× faster training, 70% less VRAM use, and 10× longer context lengths, with no loss of accuracy.

  • QLoRA requirements: gpt-oss-20b = 14GB VRAM • gpt-oss-120b = 65GB VRAM.

  • BF16 LoRA requirements: gpt-oss-20b = 44GB VRAM • gpt-oss-120b = 210GB VRAM.


🌐 Colab gpt-oss Fine-tuning

This section covers fine-tuning gpt-oss using our Google Colab notebooks. You can also download the gpt-oss notebook, open it in your favorite code editor, and follow our local gpt-oss guide.

1

Install Unsloth (in Colab)

In Colab, run cells from top to bottom. Use Run all for the first pass. The first cell installs Unsloth (and related dependencies) and prints GPU/memory info. If a cell throws an error, simply re-run it.
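If you want to double-check what the install cell detected, a quick sanity check you can run in a fresh cell (not part of the notebook itself) is:

import torch

# Confirm the GPU Colab assigned and how much memory it has
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, f"{props.total_memory / 1024**3:.1f} GB VRAM")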

2

Configuring gpt-oss and Reasoning Effort

We’ll load gpt-oss-20b using Unsloth's linearized version, as no other version will work for QLoRA fine-tuning (a loading sketch follows the parameter list below).

Configure the following parameters:

  • max_seq_length = 1024

    • Recommended for quick testing and initial experiments.

  • load_in_4bit = True

    • Set this to False for BF16 LoRA training (note: this requires at least 43GB of VRAM). You MUST also set model_name = "unsloth/gpt-oss-20b-BF16"
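Here is a minimal loading sketch, mirroring the cell shown in the local guide further down (your notebook cell may differ slightly):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b",  # Unsloth's linearized version
    dtype = None,            # None for auto detection
    max_seq_length = 1024,
    load_in_4bit = True,     # set to False (with the BF16 model name) for 16-bit LoRA
    full_finetuning = False,
)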

You should see output similar to the example below. Note: We explicitly change the dtype to float32 to ensure correct training behavior.

3

Fine-tuning Hyperparameters (LoRA)

Now it's time to adjust your training hyperparameters. For a deeper dive into how, when, and what to tune, check out our detailed hyperparameters guide.

To avoid overfitting, monitor your training loss and avoid setting these values too high.

This step adds LoRA adapters for parameter-efficient fine-tuning. Only about 1% of the model’s parameters are trained, which makes the process significantly more efficient.
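Concretely, the notebook adds the adapters with get_peft_model, the same call shown in the local guide further down:

model = FastLanguageModel.get_peft_model(
    model,
    r = 8,                 # LoRA rank; 8, 16, 32, 64 and 128 are common choices
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,      # 0 is optimized in Unsloth
    bias = "none",
    use_gradient_checkpointing = "unsloth",  # reduces VRAM for long context
    random_state = 3407,
)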

4

Try Inference

In the notebook, there's a section called "Reasoning Effort" that demonstrates gpt-oss inference running in Colab. You can skip this step, but you'll still need to run the model later once you've finished fine-tuning it.

5

Data Preparation

For this example, we will use the HuggingFaceH4/Multilingual-Thinking dataset. This dataset contains chain-of-thought reasoning examples derived from user questions translated from English into four additional languages.

This is the same dataset referenced in OpenAI's fine-tuning cookbook.

The goal of using a multilingual dataset is to help the model learn and generalize reasoning patterns across multiple languages.

gpt-oss introduces a reasoning effort system that controls how much reasoning the model performs. By default, the reasoning effort is set to low, but you can change it by setting the reasoning_effort parameter to low, medium or high.

Example:

tokenizer.apply_chat_template(
    text, 
    tokenize = False, 
    add_generation_prompt = False,
    reasoning_effort = "medium",
)
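Before formatting, the notebook loads the dataset and defines a small formatting helper (the same one shown in the local guide further down):

from datasets import load_dataset

# Load the multilingual chain-of-thought dataset
dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split = "train")

# Apply the gpt-oss chat template to every conversation to build a "text" column
def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }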

To format the dataset, we apply a customized version of the gpt-oss prompt:

from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)

Let's inspect the dataset by printing the first example:

print(dataset[0]['text'])

One unique feature of gpt-oss is its use of the OpenAI Harmony format, which supports structured conversations, reasoning output, and tool calling. This format includes tags such as <|start|>, <|message|>, and <|return|>.

🦥 Unsloth fixes the chat template to ensure it is correct. See this tweet for technical details on our template fix.

Feel free to adapt the prompt and structure to suit your own dataset or use-case. For more guidance, refer to our dataset guide.

6

Train the model

We've pre-selected training hyperparameters for optimal results. However, you can modify them based on your specific use case. Refer to our hyperparameters guide.

In this example, we train for 30 steps to keep the run short. For a full training run, set num_train_epochs = 1 and remove the max_steps limit.
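The training cell mirrors the SFTTrainer setup shown in the local guide further down, condensed here:

from trl import SFTConfig, SFTTrainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 30,            # or set num_train_epochs = 1 for a full run
        learning_rate = 2e-4,
        logging_steps = 1,         # print the loss every step
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)
trainer.train()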

During training, monitor the loss to ensure that it is decreasing over time. This confirms that the training process is functioning correctly.

7

Inference: Run your trained model

Now it's time to run inference with your fine-tuned model. You can modify the system prompt and user message; the model will generate the response.

In this example, we test the model's ability to reason in French by adding a specific instruction to the system prompt, following the same structure used in our dataset.
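The inference cell mirrors the one shown in the local guide further down:

messages = [
    {"role": "system", "content": "reasoning language: French\n\nYou are a helpful assistant that can solve mathematical problems."},
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "medium",
).to(model.device)

from transformers import TextStreamer
_ = model.generate(**inputs, max_new_tokens = 2048, streamer = TextStreamer(tokenizer))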

This should produce an output similar to:

8

Save/export your model

To save your fine-tuned model, it must be exported in the Safetensors format.

To save your LoRA adapters locally and optionally push them to the Hugging Face Hub, follow these steps:

model.save_pretrained("finetuned_model)
tokenizer.save_pretrained("finetuned_model")
model.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
tokenizer.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF

🖥️ Local gpt-oss Fine-tuning

This chapter covers fine-tuning gpt-oss on your local device. While gpt-oss-20b fine-tuning can operate on just 14GB VRAM, we recommend having at least 16GB VRAM available to ensure stable and reliable training runs.

We recommend downloading our Colab notebook, or incorporating parts of it into your local setup, for easier use.

1

Install Unsloth Locally

Ensure your device is Unsloth-compatible; see our detailed installation guide for specifics.

Note that pip install unsloth will not work for this setup, as we need to use the latest PyTorch, Triton and related packages. Install Unsloth using this specific command:

# We're installing the latest Torch, Triton, OpenAI's Triton kernels, Transformers and Unsloth!
!pip install --upgrade -qqq uv
try: import numpy; install_numpy = f"numpy=={numpy.__version__}"
except: install_numpy = "numpy"
!uv pip install -qqq \
    "torch>=2.8.0" "triton>=3.4.0" {install_numpy} \
    "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo" \
    "unsloth[base] @ git+https://github.com/unslothai/unsloth" \
    torchvision bitsandbytes \
    git+https://github.com/huggingface/transformers \
    git+https://github.com/triton-lang/triton.git@05b2c186c1b6c9a08375389d5efe9cb4c401c075#subdirectory=python/triton_kernels
2

Configuring gpt-oss and Reasoning Effort

We’ll load gpt-oss-20b using Unsloth's linearized version (as no other version will work for QLoRA fine-tuning). Configure the following parameters:

  • max_seq_length = 1024

    • Recommended for quick testing and initial experiments.

  • load_in_4bit = True

    • Set this to False for BF16 LoRA training (note: this requires at least 43GB of VRAM). You MUST also set model_name = "unsloth/gpt-oss-20b-BF16"

from unsloth import FastLanguageModel
import torch
max_seq_length = 1024
dtype = None

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/gpt-oss-20b-unsloth-bnb-4bit", # 20B model using bitsandbytes 4bit quantization
    "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
    "unsloth/gpt-oss-20b", # 20B model using MXFP4 format
    "unsloth/gpt-oss-120b",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b",
    dtype = dtype, # None for auto detection
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

You should see output similar to the example below. Note: We explicitly change the dtype to float32 to ensure correct training behavior.

3

Fine-tuning Hyperparameters (LoRA)

Now it's time to adjust your training hyperparameters. For a deeper dive into how, when, and what to tune, check out our detailed hyperparameters guide.

To avoid overfitting, monitor your training loss and avoid setting these values too high.

This step adds LoRA adapters for parameter-efficient fine-tuning. Only about 1% of the model’s parameters are trained, which makes the process significantly more efficient.

model = FastLanguageModel.get_peft_model(
    model,
    r = 8, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
4

Data Preparation

For this example, we will use the HuggingFaceH4/Multilingual-Thinking dataset. This dataset contains chain-of-thought reasoning examples derived from user questions translated from English into four additional languages.

This is the same dataset referenced in OpenAI's fine-tuning cookbook. The goal of using a multilingual dataset is to help the model learn and generalize reasoning patterns across multiple languages.

def formatting_prompts_func(examples):
    convos = examples["messages"]
    texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
    return { "text" : texts, }
pass

from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
dataset

gpt-oss introduces a reasoning effort system that controls how much reasoning the model performs. By default, the reasoning effort is set to low, but you can change it by setting the reasoning_effort parameter to low, medium or high.

Example:

tokenizer.apply_chat_template(
    text, 
    tokenize = False, 
    add_generation_prompt = False,
    reasoning_effort = "medium",
)

To format the dataset, we apply a customized version of the gpt-oss prompt:

from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched = True,)

Let's inspect the dataset by printing the first example:

print(dataset[0]['text'])

One unique feature of gpt-oss is its use of the OpenAI Harmony format, which supports structured conversations, reasoning output, and tool calling. This format includes tags such as <|start|>, <|message|>, and <|return|>.

🦥 Unsloth fixes the chat template to ensure it is correct. See this tweet for technical details on our template fix.

Feel free to adapt the prompt and structure to suit your own dataset or use-case. For more guidance, refer to our dataset guide.

5

Train the model

We've pre-selected training hyperparameters for optimal results. However, you can modify them based on your specific use case. Refer to our hyperparameters guide.

In this example, we train for 30 steps to keep the run short. For a full training run, set num_train_epochs = 1 and remove the max_steps limit.

from trl import SFTConfig, SFTTrainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 30,
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)
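To start the run, call train(); the notebook also records memory and timing stats, but the minimal call is:

trainer_stats = trainer.train()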

During training, monitor the loss to ensure that it is decreasing over time. This confirms that the training process is functioning correctly.

6

Inference: Run Your Trained Model

Now it's time to run inference with your fine-tuned model. You can modify the system prompt and user message; the model will generate the response.

In this example, we test the model's ability to reason in French by adding a specific instruction to the system prompt, following the same structure used in our dataset.

messages = [
    {"role": "system", "content": "reasoning language: French\n\nYou are a helpful assistant that can solve mathematical problems."},
    {"role": "user", "content": "Solve x^5 + 3x^4 - 10 = 3."},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,
    return_tensors = "pt",
    return_dict = True,
    reasoning_effort = "medium",
).to(model.device)
from transformers import TextStreamer
_ = model.generate(**inputs, max_new_tokens = 2048, streamer = TextStreamer(tokenizer))

This should produce an output similar to:

7

Save and Export Your Model

To save your fine-tuned model, it must be exported in the Safetensors format.

To save your LoRA adapters locally and optionally push them to the Hugging Face Hub, follow these steps:

model.save_pretrained("finetuned_model)
tokenizer.save_pretrained("finetuned_model")
model.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
tokenizer.push_to_hub("hf_username/finetuned_model", token = "hf_...") # Save to HF
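If you later want to load the saved adapters back for inference, one option is to point Unsloth at the adapter folder (a sketch, assuming the "finetuned_model" directory saved above):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "finetuned_model",  # local LoRA adapter folder saved above
    max_seq_length = 1024,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's faster inference mode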

🏁 And that's it!

You've fine-tuned gpt-oss with Unsloth. We're currently working on RL and GRPO implementations, as well as improved model saving and running, so stay tuned.

As always, feel free to drop by our Discord or Reddit if you need any help.

❓FAQ (Frequently Asked Questions)

1. Can I export my model to use in Hugging Face, llama.cpp GGUF or vLLM later?

Yes, but only if you do LoRA fine-tuning with our BF16 weights. This means you must set model_name = "unsloth/gpt-oss-20b-BF16". Keep in mind VRAM usage will be about 4× higher, so gpt-oss-20b will require about 45GB of VRAM.

We are working on better exporting for the model so stay tuned.

2. Can I do fp4 or MXFP4 training with gpt-oss?

No, currently no framework supports FP4 or MXFP4 training. However, Unsloth is the only framework that supports QLoRA 4-bit fine-tuning for this model, reducing VRAM use by more than 4×.

3. Can I export my model to MXFP4 format after training?

No, currently no library or framework supports this.

4. Can I do Reinforcement Learning (RL) or GRPO with gpt-oss?

No, currently no library or framework supports RL for gpt-oss. We are working on it, though fitting it on Colab will be difficult given how large the model is.


Acknowledgements: A huge thank you to Eyera for contributing to this guide!
