# PyTorch Integration

> Learn how to use LanceDB with PyTorch for training and inference.

LanceDB integrates with PyTorch for both training and inference. You store your data in LanceDB tables and feed those tables to PyTorch data loaders to train or run your models.

## Quickstart

The `Table` class in LanceDB implements the PyTorch
[Dataset](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) interface.
This means you can pass a LanceDB table directly to a PyTorch `DataLoader`.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
import lancedb
import torch
import pyarrow as pa

mem_db = lancedb.connect("memory://")
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))

# Any LanceDB table can be used as a PyTorch Dataset
dataloader = torch.utils.data.DataLoader(
    table, batch_size=1024, shuffle=True
)

for batch in dataloader:
    print(batch)
```

Although the `Table` class in LanceDB implements the `torch.utils.data.Dataset` interface, you may find that using
a table [Permutation](/training/) is more flexible.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(permutation)
```

## Output Formats

By default, a `Table` data loader will emit a `pyarrow.RecordBatch`. To convert to a different format (such as a
`torch.Tensor`), you will need to provide a custom collate function.

The `Permutation` class is more flexible.  By default, the output will be a list of dicts.  This is the default output
format of standard data loaders and usually more convenient when you are getting started.  However, there is a
significant performance penalty converting from Arrow, Lance's internal representation, to this default format.

To address this, the `Permutation` class provides a set of built-in transform functions that map the Arrow data to
other formats. The `arrow` and `polars` formats always avoid data copies. The `numpy`, `pandas`, and `torch_col`
formats avoid data copies in most cases. The `python`, `python_col`, and `torch` formats all require at least one
full copy of the data and are the slowest options.

### Using the torch\_col format with a torch data loader

The `torch_col` format is the most efficient way to convert from Arrow to a `torch.Tensor`.  It will convert the
entire Arrow batch to a *column-major* `torch.Tensor`.  In other words, given C columns and R rows, the resulting
Tensor will have shape `(C, R)`.  However, this format generates an error if you are using a
`torch.utils.data.DataLoader` with the default collation function:

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
```

This error occurs because the default collation function does not currently expect a single two-dimensional tensor.
It expects a list of tensors which it will then stack.  This is what is output by the `torch` format but that format
requires a data copy.  To avoid this error, and avoid data copies, you will need to provide a custom collation function
in addition to specifying the `torch_col` format.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
from lancedb.permutation import Permutation

permutation = Permutation.identity(table).with_format("torch_col")
dataloader = torch.utils.data.DataLoader(permutation, collate_fn=lambda x: x)
```

This will now output a single two-dimensional tensor for each batch.
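Because the tensor is column-major, downstream code usually needs to either split it into per-column tensors or transpose it to the `(R, C)` layout. The sketch below uses a hand-built stand-in tensor rather than a real LanceDB batch, and the column names `x` and `y` are illustrative.

```python
import torch

# Stand-in for one torch_col batch: C = 2 columns ("x", "y"), R = 4 rows.
batch = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.0],      # column "x"
        [10.0, 20.0, 30.0, 40.0],  # column "y"
    ]
)

# Each column is one slice along the first dimension.
x, y = batch.unbind(dim=0)

# Transposing gives a (R, C) view of the same storage without copying.
rows = batch.T

print(batch.shape, rows.shape)
```

Both `unbind` and `.T` return views, so neither step reintroduces the data copy that the `torch_col` format avoids.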

## Selecting columns

By default, the `Table` class will return all columns in the table when used as input to PyTorch. If you only need
a subset of columns, you can significantly reduce your I/O requirements by reading only those columns. The
`Permutation` class provides a `select_columns` method for this purpose.

```py Python icon=Python  theme={"theme":{"light":"vitesse-light","dark":"catppuccin-mocha"}}
from lancedb.permutation import Permutation

# Assumes `table` has "id" and "prompt" columns
permutation = Permutation.identity(table).select_columns(["id", "prompt"])
dataloader = torch.utils.data.DataLoader(
    permutation, batch_size=1024, shuffle=True
)

for batch in dataloader:
    # Each batch contains only the selected columns
    print(batch)
```
