CSRF protection in Phoenix with Sec-Fetch-Site

Phoenix projects created with mix phx.new are protected from Cross-Site Request Forgery (CSRF) attacks by CSRF tokens.

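While simple & effective, CSRF tokens have some annoying downsides: every page has to render a token (for example, in a `<meta name="csrf-token">` tag), frontend JavaScript has to read that token and attach it to its fetch/XHR requests, and the token has to be deleted whenever the session is renewed.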

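The good news is that since 2023, all major browsers send a Sec-Fetch-Site header with every request to a potentially trustworthy origin (HTTPS or localhost), which lets web servers determine whether the request is cross-origin. Over plain HTTP the header is not sent, which is why the plug below falls back to checking the Origin header.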

Filippo Valsorda, Go cryptography maintainer, implemented Go v1.25’s CrossOriginProtection middleware using Sec-Fetch-Site and Origin headers. By reading through Filippo’s notes on how this middleware works and CrossOriginProtection’s source code, I was able to write a Sec-Fetch-Site based Plug for CSRF protection that I use in my Phoenix apps.

The plug

defmodule MyAppWeb.CrossOriginProtection do
  @moduledoc """
  Implements protections against Cross-Site Request Forgery (CSRF) by rejecting
  non-safe cross-origin browser requests.

  Based on Go's [CrossOriginProtection middleware](https://cs.opensource.google/go/go/+/refs/tags/go1.25rc2:src/net/http/csrf.go)

  Cross-origin requests are currently detected with the `Sec-Fetch-Site` header
  or, when that header is absent, by comparing the host and port of the `Origin`
  header with the host and port of the request.

  The GET, HEAD, and OPTIONS methods are safe methods and are always allowed.
  We do not perform any state changing actions due to requests with safe methods.

  Requests without Sec-Fetch-Site or Origin headers are currently assumed to be
  either same-origin or non-browser requests, and are allowed.

  Rejected requests get a 403 response and the connection is halted.

  ## Options

    * `:bypass_patterns` - list of path prefixes whose requests are allowed even
      when they are detected as cross-origin. Both the request path and each
      pattern are normalized with `Path.expand/1` (resolving `.` and `..`) and
      compared on whole path segments, so `"/bypass/"` matches `/bypass` and
      `/bypass/bar` but not `/bypassfoo`.
  """
  @behaviour Plug

  @safe_methods ["GET", "HEAD", "OPTIONS"]

  # Origin port that is always accepted in the Origin fallback, in addition to
  # `conn.port`. Behind a TLS-terminating reverse proxy `conn.port` may report the
  # proxy-to-app port while the browser's Origin carries the public HTTPS port.
  @https_port 443

  @impl Plug
  def init(opts \\ []), do: opts

  @impl Plug
  def call(conn, opts) do
    bypass_patterns = Keyword.get(opts, :bypass_patterns)

    cond do
      conn.method in @safe_methods ->
        conn

      valid_sec_fetch_site_and_origin_headers?(conn) ->
        conn

      # Bypass patterns are only consulted for requests that failed the header checks.
      request_exempt?(conn, bypass_patterns) ->
        conn

      true ->
        conn
        |> Plug.Conn.resp(403, "cross-origin request detected")
        |> Plug.Conn.halt()
    end
  end

  # Returns true when the Sec-Fetch-Site header (or, in its absence, the Origin
  # header) indicates a same-origin or user-initiated request.
  defp valid_sec_fetch_site_and_origin_headers?(conn) do
    case get_req_header(conn, "sec-fetch-site") do
      # No Sec-Fetch-Site header is present.
      # Fallthrough to check the Origin header.
      nil ->
        valid_origin_header?(conn)

      "" ->
        valid_origin_header?(conn)

      "same-origin" ->
        true

      "none" ->
        true

      # "cross-site", "same-site" and any unknown value are rejected.
      _ ->
        false
    end
  end

  # Fallback check used when Sec-Fetch-Site is missing: the Origin's host must
  # equal `conn.host`, and its port must be the request's port or the HTTPS port.
  defp valid_origin_header?(conn) do
    case get_req_header(conn, "origin") do
      # Neither Sec-Fetch-Site nor Origin headers are present.
      # Either the request is same-origin or not a browser request.
      nil ->
        true

      "" ->
        true

      origin ->
        origin
        |> URI.parse()
        |> same_host_origin?(conn)
    end
  end

  # An opaque origin such as "null", or one with an unknown scheme, parses
  # without a host or port and never matches.
  defp same_host_origin?(%URI{host: host, port: port}, conn)
       when is_binary(host) and is_integer(port) do
    host == conn.host and port in [conn.port, @https_port]
  end

  defp same_host_origin?(%URI{}, _conn), do: false

  defp request_exempt?(_conn, nil),
    do: false

  defp request_exempt?(conn, bypass_patterns) do
    request_path = segment_prefix(conn.request_path)

    Enum.any?(bypass_patterns, fn pattern ->
      String.starts_with?(request_path, segment_prefix(pattern))
    end)
  end

  # Normalizes `path` (resolving `.`/`..`) and ends it with exactly one "/".
  # `Path.expand/1` drops trailing slashes, so without the re-added slash the
  # pattern "/bypass/" would also prefix-match "/bypassfoo".
  defp segment_prefix(path) do
    path
    |> Path.expand()
    |> String.trim_trailing("/")
    |> Kernel.<>("/")
  end

  defp get_req_header(conn, header),
    do: conn |> Plug.Conn.get_req_header(header) |> List.first()
end
Unit tests
defmodule MyAppWeb.CrossOriginProtectionTest do
  use ExUnit.Case, async: true

  import Plug.Conn
  import Plug.Test

  alias MyAppWeb.CrossOriginProtection
  alias Plug.Conn.Status

  # Tests adapted from https://cs.opensource.google/go/go/+/refs/tags/go1.25rc2:src/net/http/csrf_test.go

  # Rows are {name, method, sec_fetch_site, origin, expected_status}.
  # An empty string means the header is left off the request.
  @header_cases [
    {"same-origin allowed", "POST", "same-origin", "", Status.code(:ok)},
    {"none allowed", "POST", "none", "", Status.code(:ok)},
    {"cross-site blocked", "POST", "cross-site", "", Status.code(:forbidden)},
    {"same-site blocked", "POST", "same-site", "", Status.code(:forbidden)},

    # No header
    {"no header with no origin", "POST", "", "", Status.code(:ok)},
    {"no header with matching origin", "POST", "", "https://example.com", Status.code(:ok)},
    {"no header with mismatched origin", "POST", "", "https://attacker.example",
     Status.code(:forbidden)},
    {"no header with null origin", "POST", "", "null", Status.code(:forbidden)},

    # Safe methods
    {"GET allowed", "GET", "cross-site", "", Status.code(:ok)},
    {"HEAD allowed", "HEAD", "cross-site", "", Status.code(:ok)},
    {"OPTIONS allowed", "OPTIONS", "cross-site", "", Status.code(:ok)},
    {"PUT blocked", "PUT", "cross-site", "", Status.code(:forbidden)}
  ]

  # Rows are {name, path, sec_fetch_site, expected_status}. Every request carries
  # a mismatched Origin, so only the bypass list can let it through.
  @bypass_cases [
    {"bypass path without sec-fetch-site", "/bypass/", "", Status.code(:ok)},
    {"bypass path with cross-site", "/bypass/", "cross-site", Status.code(:ok)},
    {"non-bypass path without sec-fetch-site", "/api/", "", Status.code(:forbidden)},
    {"non-bypass path with cross-site", "/api/", "cross-site", Status.code(:forbidden)},
    {"redirect to bypass path without ..", "/foo/../bypass/bar", "", Status.code(:ok)},
    {"redirect to bypass path with trailing slash", "/bypass", "", Status.code(:ok)},
    {"redirect to non-bypass path with ..", "/foo/../api/bar", "", Status.code(:forbidden)},
    {"redirect to non-bypass path with trailing slash", "/api", "", Status.code(:forbidden)}
  ]

  # Runs the plug under test, passing the options through init/1 first.
  def call(conn, csrf_plug_opts \\ []) do
    opts = CrossOriginProtection.init(csrf_plug_opts)
    CrossOriginProtection.call(conn, opts)
  end

  describe "Sec-Fetch-Site header:" do
    for {name, method, sec_fetch_site, origin, expected_status} <- @header_cases do
      @tag method: method,
           sec_fetch_site: sec_fetch_site,
           origin: origin,
           expected_status: expected_status
      test name, context do
        context.method
        |> conn("https://example.com")
        |> maybe_put_req_header("sec-fetch-site", context.sec_fetch_site)
        |> maybe_put_req_header("origin", context.origin)
        |> call()
        |> assert_outcome(context.expected_status)
      end
    end
  end

  describe "pattern bypass:" do
    for {name, path, sec_fetch_site, expected_status} <- @bypass_cases do
      @tag path: path, sec_fetch_site: sec_fetch_site, expected_status: expected_status
      test name, context do
        "POST"
        |> conn("https://example.com" <> context.path)
        |> put_req_header("origin", "https://attacker.example")
        |> maybe_put_req_header("sec-fetch-site", context.sec_fetch_site)
        |> call(bypass_patterns: ["/bypass/"])
        |> assert_outcome(context.expected_status)
      end
    end
  end

  # Leaves the header off entirely when the table value is the empty string.
  defp maybe_put_req_header(conn, _header, ""), do: conn
  defp maybe_put_req_header(conn, header, value), do: put_req_header(conn, header, value)

  # Statuses below 400 mean "let through"; anything else must halt with that status.
  defp assert_outcome(conn, expected_status) when expected_status < 400 do
    refute conn.halted
  end

  defp assert_outcome(conn, expected_status) do
    assert conn.halted
    assert conn.status == expected_status
  end
end

Usage

  1. Update your router.ex pipelines:
-  plug :protect_from_forgery
+  plug MyAppWeb.CrossOriginProtection
  1. Remove references to get_csrf_token() in your layouts:
- <meta name="csrf-token" content={get_csrf_token()} />
  1. Remove references to delete_csrf_token() in user_auth.ex:
 defp renew_session(conn) do
-  delete_csrf_token()

   conn
   |> configure_session(renew: true)
   |> clear_session()
 end
  1. Update your WebSocket transport config to check origin headers instead of CSRF tokens:
 socket "/live", Phoenix.LiveView.Socket,
   websocket: [
+    check_origin: true,
+    check_csrf: false,
     connect_info: [session: @session_options]
   ]
  1. Finally, update any frontend JavaScript that expects to find a token in meta[name='csrf-token'] for use with fetch/XHR requests.