Separate cache key storage from result storage

To store cache records separately from their cached values, you can configure a cache policy to use a custom storage location.

Here’s an example of a cache policy configured to store cache records in a local directory:

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS

cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage="/path/to/cache/storage")

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42

Cache records will be stored in the specified directory while the persisted results will continue to be stored in ~/.prefect/storage.
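If you also want the persisted results somewhere other than the default, you can point the task's result_storage at its own location, independent of key_storage. Here's a minimal sketch; both paths are illustrative:

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect.filesystems import LocalFileSystem

# cache records go to one directory, persisted results to another;
# both paths are illustrative
cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage="/path/to/cache/storage")

@task(
    cache_policy=cache_policy,
    result_storage=LocalFileSystem(basepath="/path/to/result/storage"),
)
def my_cached_task(x: int):
    return x + 42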

To store cache records in a remote object store such as S3, pass a storage block instead:

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS

from prefect_aws import S3Bucket, AwsCredentials

s3_bucket = S3Bucket(
    credentials=AwsCredentials(
        aws_access_key_id="my-access-key-id",
        aws_secret_access_key="my-secret-access-key",
    ),
    bucket_name="my-bucket",
)
# save the block to ensure it is available across machines
s3_bucket.save("my-cache-records-bucket")

cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage=s3_bucket)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42

Storing cache records in a remote object store allows you to share cache records across multiple machines.
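Because the block was saved above, other machines can load it by name instead of redeclaring credentials. Here's a minimal sketch, assuming the block name from the previous example:

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS

from prefect_aws import S3Bucket

# load the previously saved block by name; no credentials needed in code
s3_bucket = S3Bucket.load("my-cache-records-bucket")

cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage=s3_bucket)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42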

Isolate cache access

You can control concurrent access to cache records by setting the isolation_level parameter on the cache policy. Prefect supports two isolation levels: READ_COMMITTED and SERIALIZABLE.

By default, cache records operate with a READ_COMMITTED isolation level. This guarantees that reading a cache record will see the latest committed cache value, but allows multiple executions of the same task to occur simultaneously.

Consider the following example:

from prefect import task
from prefect.cache_policies import INPUTS
import threading


cache_policy = INPUTS

@task(cache_policy=cache_policy)
def my_task_version_1(x: int):
    print("my_task_version_1 running")
    return x + 42

@task(cache_policy=cache_policy)
def my_task_version_2(x: int):
    print("my_task_version_2 running")
    return x + 43

if __name__ == "__main__":
    thread_1 = threading.Thread(target=my_task_version_1, args=(1,))
    thread_2 = threading.Thread(target=my_task_version_2, args=(1,))

    thread_1.start()
    thread_2.start()

    thread_1.join()
    thread_2.join()

When you run this script, both tasks execute in parallel and perform work, even though they share the same cache key.

For stricter isolation, you can use the SERIALIZABLE isolation level. This uses a locking mechanism to ensure that only one execution of a task occurs at a time for a given cache record.

When setting isolation_level to SERIALIZABLE, you must also provide a lock_manager that implements locking logic for your system.

Here’s an updated version of the previous example that uses SERIALIZABLE isolation:

import threading

from prefect import task
from prefect.cache_policies import INPUTS
from prefect.locking.memory import MemoryLockManager
from prefect.transactions import IsolationLevel

cache_policy = INPUTS.configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=MemoryLockManager(),
)


@task(cache_policy=cache_policy)
def my_task_version_1(x: int):
    print("my_task_version_1 running")
    return x + 42


@task(cache_policy=cache_policy)
def my_task_version_2(x: int):
    print("my_task_version_2 running")
    return x + 43


if __name__ == "__main__":
    thread_1 = threading.Thread(target=my_task_version_1, args=(2,))
    thread_2 = threading.Thread(target=my_task_version_2, args=(2,))

    thread_1.start()
    thread_2.start()

    thread_1.join()
    thread_2.join()

In this example, only one of the tasks will run and the other will use the cached value.

Locking in a distributed setting

To manage locks in a distributed setting, you will need to use a storage system for locks that is accessible by all of your execution infrastructure.

We recommend using the RedisLockManager provided by prefect-redis in conjunction with a shared Redis instance:

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect.transactions import IsolationLevel

from prefect_redis import RedisLockManager

cache_policy = (INPUTS + TASK_SOURCE).configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=RedisLockManager(host="my-redis-host"),
)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
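If your workers share a filesystem (such as a network mount) rather than a Redis instance, a file-based lock manager is an alternative. Here's a minimal sketch, assuming Prefect's FileSystemLockManager from prefect.locking.filesystem; the directory is illustrative and must be visible to all workers:

from pathlib import Path

from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect.locking.filesystem import FileSystemLockManager
from prefect.transactions import IsolationLevel

# lock files are written to a shared directory; the path is illustrative
cache_policy = (INPUTS + TASK_SOURCE).configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=FileSystemLockManager(lock_files_directory=Path("/mnt/shared/locks")),
)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42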

Coordinate caching across multiple tasks

To coordinate cache writes across tasks, you can run multiple tasks within a single transaction.

from prefect import task, flow
from prefect.transactions import transaction


@task(cache_key_fn=lambda *args, **kwargs: "static-key-1")
def load_data():
    return "some-data"


@task(cache_key_fn=lambda *args, **kwargs: "static-key-2")
def process_data(data, fail):
    if fail:
        raise RuntimeError("Error! Abort!")

    return len(data)


@flow
def multi_task_cache(fail: bool = True):
    with transaction():
        data = load_data()
        process_data(data=data, fail=fail)

When this flow is run with the default parameter values, it will fail on the process_data task after the load_data task has succeeded.

However, because caches are only written to when a transaction is committed, the load_data task will not write a result to its cache key location until the process_data task succeeds as well.

On a subsequent run with fail=False, both tasks will be re-executed and the results will be cached.
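To observe this behavior, you can run the flow twice; a minimal sketch:

if __name__ == "__main__":
    # first run: process_data raises, the transaction never commits,
    # and neither task writes to its cache key
    try:
        multi_task_cache()
    except RuntimeError:
        pass

    # second run: both tasks re-execute, the transaction commits,
    # and both results are written to their cache keys
    multi_task_cache(fail=False)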

Handle non-serializable objects

You may have task inputs that can’t (or shouldn’t) be serialized as part of the cache key. There are two direct approaches to handle this, both of which are based on the same idea.

You can adjust the serialization logic to only serialize certain properties of an input:

  1. Using a custom cache key function:
from prefect import flow, task
from prefect.cache_policies import CacheKeyFnPolicy, RUN_ID
from prefect.context import TaskRunContext
from pydantic import BaseModel, ConfigDict

class NotSerializable:
    def __getstate__(self):
        raise TypeError("NooOoOOo! I will not be serialized!")

class ContainsNonSerializableObject(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    bad_object: NotSerializable

def custom_cache_key_fn(context: TaskRunContext, parameters: dict) -> str:
    return parameters["some_object"].name

@task(cache_policy=CacheKeyFnPolicy(cache_key_fn=custom_cache_key_fn) + RUN_ID)
def use_object(some_object: ContainsNonSerializableObject) -> str:
    return f"Used {some_object.name}"


@flow
def demo_flow():
    obj = ContainsNonSerializableObject(name="test", bad_object=NotSerializable())
    state = use_object(obj, return_state=True) # Not cached!
    assert state.name == "Completed"
    other_state = use_object(obj, return_state=True) # Cached!
    assert other_state.name == "Cached"
    assert state.result() == other_state.result()
  2. Using Pydantic’s custom serialization on your input types:
from pydantic import BaseModel, ConfigDict, model_serializer
from prefect import flow, task
from prefect.cache_policies import INPUTS, RUN_ID

class NotSerializable:
    def __getstate__(self):
        raise TypeError("NooOoOOo! I will not be serialized!")

class ContainsNonSerializableObject(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    bad_object: NotSerializable

    @model_serializer
    def ser_model(self) -> dict:
        """Only serialize the name, not the problematic object"""
        return {"name": self.name}

@task(cache_policy=INPUTS + RUN_ID)
def use_object(some_object: ContainsNonSerializableObject) -> str:
    return f"Used {some_object.name}"

@flow
def demo_flow():
    some_object = ContainsNonSerializableObject(
        name="test",
        bad_object=NotSerializable()
    )
    state = use_object(some_object, return_state=True) # Not cached!
    assert state.name == "Completed"
    other_state = use_object(some_object, return_state=True) # Cached!
    assert other_state.name == "Cached"
    assert state.result() == other_state.result()

Choose the approach that best fits your needs:

  • Use Pydantic models when you want consistent serialization across your application
  • Use custom cache key functions when you need different caching logic for different tasks