## Separate cache key storage from result storage
To store cache records separately from the cached value, you can configure a cache policy to use a custom storage location.
Here’s an example of a cache policy configured to store cache records in a local directory:
```python
from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS

cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage="/path/to/cache/storage")

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
```
Cache records will be stored in the specified directory, while the persisted results will continue to be stored in `~/.prefect/storage`.
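As a quick sanity check, calling the task twice with the same input should execute the body once and resolve the second call from the cache record. Here is a minimal sketch building on the task above:

```python
# A minimal usage sketch: the first call runs the task body; the second
# call finds a matching cache record under /path/to/cache/storage and
# returns the persisted result without re-executing.
if __name__ == "__main__":
    print(my_cached_task(1))  # executes and writes a cache record
    print(my_cached_task(1))  # resolves from the cache record
```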
To store cache records in a remote object store such as S3, pass a storage block instead:
```python
from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect_aws import S3Bucket, AwsCredentials

s3_bucket = S3Bucket(
    credentials=AwsCredentials(
        aws_access_key_id="my-access-key-id",
        aws_secret_access_key="my-secret-access-key",
    ),
    bucket_name="my-bucket",
)

# save the block to ensure it is available across machines
s3_bucket.save("my-cache-records-bucket")

cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage=s3_bucket)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
```
Storing cache records in a remote object store allows you to share cache records across multiple machines.
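Because the block was saved above, other machines can load it by name rather than re-declaring credentials. Here is a minimal sketch of the consumer side, using the block name from the example above:

```python
# A minimal sketch: on another machine, load the saved block by name and
# build the same cache policy so both machines share one set of cache records.
from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect_aws import S3Bucket

s3_bucket = S3Bucket.load("my-cache-records-bucket")
cache_policy = (TASK_SOURCE + INPUTS).configure(key_storage=s3_bucket)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
```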
## Isolate cache access
You can control concurrent access to cache records by setting the `isolation_level` parameter on the cache policy. Prefect supports two isolation levels: `READ_COMMITTED` and `SERIALIZABLE`.

By default, cache records operate with a `READ_COMMITTED` isolation level. This guarantees that reading a cache record will see the latest committed cache value, but allows multiple executions of the same task to occur simultaneously.
Consider the following example:
```python
import threading

from prefect import task
from prefect.cache_policies import INPUTS

cache_policy = INPUTS

@task(cache_policy=cache_policy)
def my_task_version_1(x: int):
    print("my_task_version_1 running")
    return x + 42

@task(cache_policy=cache_policy)
def my_task_version_2(x: int):
    print("my_task_version_2 running")
    return x + 43

if __name__ == "__main__":
    thread_1 = threading.Thread(target=my_task_version_1, args=(1,))
    thread_2 = threading.Thread(target=my_task_version_2, args=(1,))
    thread_1.start()
    thread_2.start()
    thread_1.join()
    thread_2.join()
```
When you run this script, both tasks execute in parallel and perform work, despite sharing the same cache key (the `INPUTS` policy keys only on the task's inputs, and both tasks receive `x=1`).
For stricter isolation, you can use the `SERIALIZABLE` isolation level. This ensures that only one execution of a task occurs at a time for a given cache record via a locking mechanism.

When setting `isolation_level` to `SERIALIZABLE`, you must also provide a `lock_manager` that implements locking logic for your system.

Here's an updated version of the previous example that uses `SERIALIZABLE` isolation:
```python
import threading

from prefect import task
from prefect.cache_policies import INPUTS
from prefect.locking.memory import MemoryLockManager
from prefect.transactions import IsolationLevel

cache_policy = INPUTS.configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=MemoryLockManager(),
)

@task(cache_policy=cache_policy)
def my_task_version_1(x: int):
    print("my_task_version_1 running")
    return x + 42

@task(cache_policy=cache_policy)
def my_task_version_2(x: int):
    print("my_task_version_2 running")
    return x + 43

if __name__ == "__main__":
    thread_1 = threading.Thread(target=my_task_version_1, args=(2,))
    thread_2 = threading.Thread(target=my_task_version_2, args=(2,))
    thread_1.start()
    thread_2.start()
    thread_1.join()
    thread_2.join()
```
In this example, only one of the tasks will do work; the other waits for the lock to release and then uses the cached value.
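Note that `MemoryLockManager` only coordinates threads within a single process. To serialize task runs across multiple processes on the same machine, a file-based lock manager can be swapped in. Here is a minimal sketch, assuming a `FileSystemLockManager` in `prefect.locking.filesystem` and an arbitrary shared lock directory:

```python
# A minimal sketch for cross-process locking on a single machine.
# Assumes prefect.locking.filesystem.FileSystemLockManager; the lock
# directory is arbitrary, but every process must point at the same path.
from pathlib import Path

from prefect import task
from prefect.cache_policies import INPUTS
from prefect.locking.filesystem import FileSystemLockManager
from prefect.transactions import IsolationLevel

cache_policy = INPUTS.configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=FileSystemLockManager(
        lock_files_directory=Path("/tmp/prefect-locks")
    ),
)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
```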
### Locking in a distributed setting
To manage locks in a distributed setting, you will need to use a storage system for locks that is accessible by all of your execution infrastructure. We recommend using the `RedisLockManager` provided by `prefect-redis` in conjunction with a shared Redis instance:
```python
from prefect import task
from prefect.cache_policies import TASK_SOURCE, INPUTS
from prefect.transactions import IsolationLevel
from prefect_redis import RedisLockManager

cache_policy = (INPUTS + TASK_SOURCE).configure(
    isolation_level=IsolationLevel.SERIALIZABLE,
    lock_manager=RedisLockManager(host="my-redis-host"),
)

@task(cache_policy=cache_policy)
def my_cached_task(x: int):
    return x + 42
```
## Coordinate caching across multiple tasks
To coordinate cache writes across tasks, you can run multiple tasks within a single transaction.
```python
from prefect import task, flow
from prefect.transactions import transaction

@task(cache_key_fn=lambda *args, **kwargs: "static-key-1")
def load_data():
    return "some-data"

@task(cache_key_fn=lambda *args, **kwargs: "static-key-2")
def process_data(data, fail):
    if fail:
        raise RuntimeError("Error! Abort!")
    return len(data)

@flow
def multi_task_cache(fail: bool = True):
    with transaction():
        data = load_data()
        process_data(data=data, fail=fail)
```
When this flow is run with the default parameter values, it will fail on the `process_data` task after the `load_data` task has succeeded. However, because caches are only written when a transaction is committed, the `load_data` task will not write a result to its cache key location until the `process_data` task succeeds as well.

On a subsequent run with `fail=False`, both tasks will be re-executed and the results will be cached.
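Concretely, the sequence looks like this; a minimal sketch, assuming the failing flow call raises the task's error:

```python
# A minimal sketch of the rerun behavior described above.
if __name__ == "__main__":
    try:
        multi_task_cache()  # fails in process_data; no cache keys are written
    except Exception:
        pass

    multi_task_cache(fail=False)  # both tasks execute; both results are cached
    multi_task_cache(fail=False)  # both tasks now resolve from their cache keys
```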
## Handling non-serializable objects
You may have task inputs that can't (or shouldn't) be serialized as part of the cache key. There are two direct approaches to handle this, both based on the same idea of adjusting the serialization logic to serialize only certain properties of an input:
- Using a custom cache key function:
```python
from prefect import flow, task
from prefect.cache_policies import CacheKeyFnPolicy, RUN_ID
from prefect.context import TaskRunContext
from pydantic import BaseModel, ConfigDict

class NotSerializable:
    def __getstate__(self):
        raise TypeError("NooOoOOo! I will not be serialized!")

class ContainsNonSerializableObject(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    bad_object: NotSerializable

def custom_cache_key_fn(context: TaskRunContext, parameters: dict) -> str:
    return parameters["some_object"].name

@task(cache_policy=CacheKeyFnPolicy(cache_key_fn=custom_cache_key_fn) + RUN_ID)
def use_object(some_object: ContainsNonSerializableObject) -> str:
    return f"Used {some_object.name}"

@flow
def demo_flow():
    obj = ContainsNonSerializableObject(name="test", bad_object=NotSerializable())
    state = use_object(obj, return_state=True)  # Not cached!
    assert state.name == "Completed"
    other_state = use_object(obj, return_state=True)  # Cached!
    assert other_state.name == "Cached"
    assert state.result() == other_state.result()
```
- Using Pydantic’s custom serialization on your input types:
```python
from pydantic import BaseModel, ConfigDict, model_serializer

from prefect import flow, task
from prefect.cache_policies import INPUTS, RUN_ID

class NotSerializable:
    def __getstate__(self):
        raise TypeError("NooOoOOo! I will not be serialized!")

class ContainsNonSerializableObject(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    name: str
    bad_object: NotSerializable

    @model_serializer
    def ser_model(self) -> dict:
        """Only serialize the name, not the problematic object"""
        return {"name": self.name}

@task(cache_policy=INPUTS + RUN_ID)
def use_object(some_object: ContainsNonSerializableObject) -> str:
    return f"Used {some_object.name}"

@flow
def demo_flow():
    some_object = ContainsNonSerializableObject(
        name="test",
        bad_object=NotSerializable(),
    )
    state = use_object(some_object, return_state=True)  # Not cached!
    assert state.name == "Completed"
    other_state = use_object(some_object, return_state=True)  # Cached!
    assert other_state.name == "Cached"
    assert state.result() == other_state.result()
```
Choose the approach that best fits your needs:
- Use Pydantic models when you want consistent serialization across your application
- Use custom cache key functions when you need different caching logic for different tasks