Building Conversational AI with LLMs and Agents
Appendix T: Distributed ML: PySpark, Databricks, and Ray

Databricks AI and Foundation Models

Big Picture

Databricks has evolved from a data engineering platform into a full-stack AI platform. With the Mosaic AI suite, Databricks provides managed infrastructure for training large models, hosted foundation model APIs, SQL-native LLM functions, MLflow-based lifecycle management, and a built-in vector search service. This section covers each layer of that stack, from fine-tuning custom models with Composer to building end-to-end RAG applications that keep all components inside the Databricks Lakehouse. Teams that already use Databricks for data engineering can reach production-grade AI applications without introducing separate infrastructure for training, serving, or retrieval.

T.4.1 Mosaic AI: Training LLMs on Databricks

Mosaic AI is the unified AI development layer inside Databricks, built on technology from MosaicML (acquired by Databricks in 2023). It bundles two open-source libraries: Composer (a PyTorch training library with built-in efficiency algorithms) and Streaming (a data-loading library, with its own sharded MDS format, designed for efficient random-access reads from object storage during distributed training). Together they let teams train or fine-tune large models on Databricks clusters without managing Ray or DeepSpeed directly, though those backends can be used underneath. See Section T.3 for the underlying Ray integration and Appendix K for HuggingFace model loading patterns that combine naturally with Composer.

The Mosaic AI managed fine-tuning endpoint requires no cluster configuration. You submit a training job via the Databricks UI or REST API, specifying a base model, a Delta table of instruction pairs, and hyperparameters. Databricks provisions the compute, runs the training loop, logs results to MLflow, and registers the resulting model in Unity Catalog.

import os
import requests

# Submit a managed fine-tuning job via the Databricks REST API
DATABRICKS_HOST = "https://your-workspace.azuredatabricks.net"
# Read the personal access token from the environment instead of hardcoding it
TOKEN = os.environ["DATABRICKS_TOKEN"]

payload = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "train_data_path": "dbfs:/user/data/instruction_pairs_train",
    "eval_data_path": "dbfs:/user/data/instruction_pairs_eval",
    "training_duration": "3ep",           # 3 epochs
    "learning_rate": 5e-6,
    "register_to": "ml_catalog.llm_models.llama_support_ft",
    "data_prep_cluster_id": "0101-123456-abcdefgh",
}

response = requests.post(
    f"{DATABRICKS_HOST}/api/2.0/fine-tuning/runs/create",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json=payload,
)
response.raise_for_status()  # fail fast on authentication or validation errors
run_id = response.json()["run_id"]
print(f"Fine-tuning run started: {run_id}")

# Direct Composer training on a Databricks GPU cluster
from composer import Trainer
from composer.models import HuggingFaceModel
from composer.optim import DecoupledAdamW
from streaming import StreamingDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
)

# Wrap in Composer's HuggingFace adapter
composer_model = HuggingFaceModel(
    model=base_model,
    tokenizer=tokenizer,
    use_logits=True,
)

# Streaming dataset read from object storage (MDS format written by Spark)
train_dataset = StreamingDataset(
    remote="s3://my-bucket/training-data/mds/",
    local="/tmp/streaming_cache/",
    shuffle=True,
    batch_size=8,
)

trainer = Trainer(
    model=composer_model,
    train_dataloader=torch.utils.data.DataLoader(train_dataset, batch_size=8),
    max_duration="3ep",
    optimizers=DecoupledAdamW(composer_model.parameters(), lr=5e-6, weight_decay=0.0),
    device="gpu",
    precision="amp_bf16",  # Composer's name for bf16 mixed precision
    save_folder="/dbfs/user/checkpoints/llama_sft/",
    loggers=[],  # to log to MLflow, pass composer.loggers.MLFlowLogger(...) here
)
trainer.fit()

# Calling the Foundation Model APIs with the OpenAI client
from openai import OpenAI

# The openai client works with Databricks Foundation Model APIs
# by overriding the base_url and using a Databricks PAT as the key
client = OpenAI(
    api_key="dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
    base_url="https://your-workspace.azuredatabricks.net/serving-endpoints",
)

response = client.chat.completions.create(
    model="databricks-dbrx-instruct",   # or databricks-meta-llama-3-1-70b-instruct
    messages=[
        {"role": "system", "content": "You are a helpful data engineering assistant."},
        {"role": "user", "content": "Explain the difference between Delta Live Tables and standard Delta tables."},
    ],
    max_tokens=512,
    temperature=0.3,
)
print(response.choices[0].message.content)
Code Fragment T.4.1: Submitting a managed fine-tuning job via the Databricks REST API. The run logs automatically to MLflow and registers the finished model in Unity Catalog.
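
The managed fine-tuning job described earlier in this subsection consumes a table of instruction pairs. The following is a minimal sketch of preparing such pairs before loading them into a Delta table. The field names `prompt` and `response` and the helper `to_jsonl` are assumptions made for illustration; check the schema your workspace expects before relying on them.

```python
import json


def to_jsonl(pairs):
    """Validate (prompt, response) pairs and serialise them as JSON Lines.

    The "prompt"/"response" field names are an assumption for this sketch.
    """
    lines = []
    for i, (prompt, response) in enumerate(pairs):
        if not prompt.strip() or not response.strip():
            raise ValueError(f"pair {i} has an empty prompt or response")
        record = {"prompt": prompt.strip(), "response": response.strip()}
        lines.append(json.dumps(record))
    return "\n".join(lines)


# Sample pairs reused from the evaluation example in this section
pairs = [
    ("How do I reset my password?", "Go to Settings > Security > Reset Password."),
    ("Where is my order?", "Check your order status at orders.example.com."),
]
jsonl = to_jsonl(pairs)
print(jsonl)
```

Rejecting empty fields up front is cheaper than discovering malformed rows after a training run has already been provisioned.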

For teams that need more control, Composer can be used directly on a Databricks cluster. The StreamingDataset class reads sharded MDS-format data from DBFS or cloud storage with shuffle buffers that keep GPU utilisation high even at slow network speeds.

-- Classify customer support tickets by category using ai_classify()
SELECT
    ticket_id,
    ticket_text,
    ai_classify(
        ticket_text,
        ARRAY('billing', 'technical', 'account_access', 'feature_request', 'other')
    ) AS category,
    ai_extract(
        ticket_text,
        ARRAY('product_name', 'error_code', 'urgency_level')
    ) AS extracted_fields,
    ai_summarize(ticket_text, 50) AS short_summary
FROM raw_support_tickets
WHERE created_date >= CURRENT_DATE - INTERVAL 1 DAY;
Code Fragment T.4.2: Direct Composer training on a Databricks GPU cluster. StreamingDataset streams shards from object storage, avoiding the need to copy the full dataset to local disk before training begins.
Note

The MPT model family (MPT-7B, MPT-30B) was trained by MosaicML using Composer and Streaming. These models use ALiBi positional encodings and were designed for long-context use. They remain useful as base models for domain-specific fine-tuning, especially when a commercially licensable model is required. See Appendix K for loading MPT models from the HuggingFace Hub.

T.4.2 Foundation Model APIs

Databricks Foundation Model APIs provide pay-per-token access to hosted open-weight models (DBRX, Llama 3, Mistral, Mixtral) without the overhead of managing GPU infrastructure. The API follows the OpenAI Chat Completions format, so the standard openai client library works by pointing it at your Databricks workspace URL. This makes it straightforward to swap between hosted Databricks models and external providers such as OpenAI or Anthropic during development. For background on the OpenAI-compatible API contract, see Chapter 10.
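
Because the request format is shared, switching providers can be reduced to a configuration lookup. The sketch below is one possible way to organise that, not a Databricks feature: the `PROVIDERS` registry, the environment variable names, and the `client_settings` helper are all hypothetical.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class Provider:
    base_url: str      # root URL passed to the OpenAI-compatible client
    api_key_env: str   # name of the environment variable holding the key
    model: str         # default model or endpoint name for this provider


# Hypothetical registry; URLs and model names are placeholders to adapt
PROVIDERS = {
    "databricks": Provider(
        base_url="https://your-workspace.azuredatabricks.net/serving-endpoints",
        api_key_env="DATABRICKS_TOKEN",
        model="databricks-meta-llama-3-1-70b-instruct",
    ),
    "openai": Provider(
        base_url="https://api.openai.com/v1",
        api_key_env="OPENAI_API_KEY",
        model="gpt-4o-mini",
    ),
}


def client_settings(name):
    """Return the base_url, api_key, and model to use for one provider."""
    provider = PROVIDERS[name]
    return {
        "base_url": provider.base_url,
        "api_key": os.environ.get(provider.api_key_env, ""),
        "model": provider.model,
    }


settings = client_settings("databricks")
print(settings["base_url"], settings["model"])
```

The `base_url` and `api_key` values can be passed to the `OpenAI(...)` constructor and `model` to `chat.completions.create(...)`, so application code stays the same when the provider changes.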

-- Use ai_query() for custom prompts inside a Delta Live Tables pipeline
CREATE OR REFRESH STREAMING TABLE product_reviews_enriched AS
SELECT
    review_id,
    product_id,
    review_text,
    rating,
    -- Call the Foundation Model API for sentiment and entity extraction
    ai_query(
        'databricks-meta-llama-3-1-70b-instruct',
        CONCAT(
            'Analyze this product review. Return a JSON object with keys: ',
            '"sentiment" (positive/negative/neutral), ',
            '"mentioned_features" (array of product features mentioned), ',
            '"improvement_suggestions" (array of suggestions, empty if none). ',
            'Review: ', review_text
        )
    ) AS llm_analysis,
    -- Also classify sentiment directly
    ai_classify(review_text, ARRAY('positive', 'negative', 'neutral')) AS sentiment_label
FROM STREAM(LIVE.raw_product_reviews);
Code Fragment T.4.3: Accessing Databricks Foundation Model APIs using the OpenAI client library. The only changes from a standard OpenAI call are the base_url and the Databricks personal access token.

Provisioned throughput is an alternative pricing tier that reserves dedicated capacity, measured in tokens per second. It is appropriate when latency must be predictable (no cold starts, no queuing behind other tenants) or when usage volume makes per-token pricing uneconomical. Provisioned throughput endpoints are created in the Databricks Serving UI and billed by the hour for the capacity reserved, whether or not that capacity is used.
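
The choice between the two pricing tiers comes down to a break-even volume. The sketch below works through that arithmetic under simplifying assumptions: a single blended price for input and output tokens, and reserved capacity large enough to serve the load. The prices are placeholders, not Databricks list prices.

```python
def breakeven_tokens_per_hour(hourly_reserved_cost, price_per_million_tokens):
    """Tokens per hour at which reserved capacity costs the same as pay-per-token."""
    return hourly_reserved_cost / price_per_million_tokens * 1_000_000


def cheaper_option(tokens_per_hour, hourly_reserved_cost, price_per_million_tokens):
    """Compare the hourly cost of both tiers for a given sustained load."""
    pay_per_token_cost = tokens_per_hour / 1_000_000 * price_per_million_tokens
    if hourly_reserved_cost < pay_per_token_cost:
        return "provisioned"
    return "pay-per-token"


# Placeholder prices, chosen only to make the arithmetic easy to follow
HOURLY_RESERVED_COST = 12.0      # currency units per hour of reserved capacity
PRICE_PER_MILLION_TOKENS = 2.0   # currency units per million tokens

threshold = breakeven_tokens_per_hour(HOURLY_RESERVED_COST, PRICE_PER_MILLION_TOKENS)
print(f"Break-even at {threshold:,.0f} tokens per hour")
print(cheaper_option(10_000_000, HOURLY_RESERVED_COST, PRICE_PER_MILLION_TOKENS))
```

Because reserved capacity is billed around the clock, the comparison should use average sustained load rather than peak load.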

Key Insight

For enterprise use cases, Databricks Foundation Model APIs have a governance advantage over external providers: all data stays inside your cloud tenant, audit logs are captured in Unity Catalog, and model access is governed by the same RBAC policies as your data. When comparing with OpenAI or Anthropic APIs, factor in data residency requirements, latency SLAs, and the operational cost of managing separate API keys and audit trails.

T.4.3 AI Functions: LLM Calls from SQL

Databricks AI Functions expose LLM capabilities as built-in SQL functions, allowing data engineers to invoke models directly inside SQL queries, Delta Live Tables pipelines, and dbt models. The core function is ai_query(), which calls any Foundation Model API endpoint. Wrapper functions like ai_classify(), ai_extract(), and ai_summarize() provide structured interfaces for common text analytics tasks. This integrates naturally with the Delta Lake architecture covered in Section T.2 and the PySpark pipelines from Section T.6.

import mlflow
import pandas as pd

mlflow.set_experiment("/Users/you@company.com/llm-finetuning-v3")

with mlflow.start_run(run_name="llama-support-ft-eval"):
    # Log hyperparameters
    mlflow.log_params({
        "base_model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "learning_rate": 5e-6,
        "training_epochs": 3,
        "train_samples": 12400,
    })

    # Log prompt template used for fine-tuning as an artifact
    mlflow.log_text(
        "You are a helpful support agent. Answer the customer's question concisely.\n\nQuestion: {question}\n\nAnswer:",
        "prompt_template.txt",
    )

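    # Evaluate the fine-tuned model with mlflow.evaluate()
    # The faithfulness metric grades each answer against a "context" column,
    # so the evaluation set carries reference passages (illustrative values)
    eval_data = pd.DataFrame({
        "inputs": ["How do I reset my password?", "Where is my order?"],
        "context": [
            "Passwords are reset under Settings > Security > Reset Password.",
            "Order status is shown at orders.example.com.",
        ],
        "ground_truth": ["Go to Settings > Security > Reset Password.", "Check your order status at orders.example.com."],
    })

    results = mlflow.evaluate(
        # Unity Catalog model URI: models:/<catalog>.<schema>.<model>/<version>
        model="models:/ml_catalog.llm_models.llama_support_ft/1",
        data=eval_data,
        targets="ground_truth",
        model_type="question-answering",
        evaluators="default",
        extra_metrics=[
            # LLM-judged metrics live in the mlflow.metrics.genai namespace
            mlflow.metrics.genai.faithfulness(),
            mlflow.metrics.genai.answer_relevance(),
            mlflow.metrics.toxicity(),
        ],
    )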

    # Log aggregate metrics from the evaluation
    mlflow.log_metrics({
        "faithfulness_mean": results.metrics["faithfulness/v1/mean"],
        "relevance_mean": results.metrics["answer_relevance/v1/mean"],
        "toxicity_ratio": results.metrics["toxicity/v1/ratio"],
    })

    print(results.tables["eval_results_table"])
Code Fragment T.4.4: Using ai_classify(), ai_extract(), and ai_summarize() to enrich raw support tickets in a single SQL query. Results can be written directly to a Delta table.

ai_query() provides lower-level access, letting you pass a custom prompt and parse the response. It accepts a model endpoint name and a prompt string, and returns the model's text response as a SQL string.
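
Because ai_query() returns plain text, a prompt that asks for JSON still needs defensive parsing downstream. The following is a minimal sketch of that step in Python; the helper `parse_llm_json` is hypothetical, and the key names are taken from the product-review prompt used in this section.

```python
import json
import re

REQUIRED_KEYS = ("sentiment", "mentioned_features", "improvement_suggestions")


def parse_llm_json(text, required_keys=REQUIRED_KEYS):
    """Extract a JSON object from a model response, or return None if invalid."""
    if not text:
        return None
    # Models often wrap the object in extra prose, so grab the outermost braces
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    # Reject responses that are not objects or that miss an expected key
    if not isinstance(parsed, dict) or not set(required_keys).issubset(parsed):
        return None
    return parsed


raw = (
    'Sure, here is the analysis: {"sentiment": "positive", '
    '"mentioned_features": ["battery life"], "improvement_suggestions": []} '
    "Let me know if you need more."
)
analysis = parse_llm_json(raw)
print(analysis)
```

Returning None instead of raising keeps a single malformed response from failing a whole batch; the affected rows can be filtered out and retried separately.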

from mlflow.tracking import MlflowClient

client = MlflowClient()
model_name = "ml_catalog.llm_models.llama_support_ft"

# Transition version 3 to Champion (replaces the previous champion)
client.set_registered_model_alias(
    name=model_name,
    alias="Champion",
    version="3",
)

# Create or update a Model Serving endpoint to use the new champion
import requests

DATABRICKS_HOST = "https://your-workspace.azuredatabricks.net"
TOKEN = "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

endpoint_config = {
    "name": "support-llm-endpoint",
    "config": {
        "served_models": [{
            "name": "llama-support-ft-champion",
            "model_name": model_name,
            "model_version": "3",
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
        }]
    }
}

response = requests.put(
    f"{DATABRICKS_HOST}/api/2.0/serving-endpoints/support-llm-endpoint/config",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json=endpoint_config["config"],
)
response.raise_for_status()  # surface a failed config update instead of ignoring it
Code Fragment T.4.5: Enriching a streaming Delta Live Tables pipeline with AI functions. The LLM call runs once per row as the stream processes; use provisioned throughput endpoints to control latency and cost at high row volumes.
Warning

AI functions add per-token cost to every SQL query or pipeline run that invokes them. At large scales (millions of rows), costs can exceed the compute cost of the query itself. Use column filters and partition pruning to limit rows processed, cache results in a Delta table rather than re-running on every query, and consider batching with the Python API instead of per-row SQL calls for bulk enrichment jobs.
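
The caching advice above can be sketched in a few lines of Python. This is an illustration of the idea only: the in-memory dictionary stands in for a Delta table keyed by content hash, and `fake_model` is a hypothetical stand-in for a real model call.

```python
import hashlib


def text_key(text):
    """Hash whitespace- and case-normalised text so near-identical rows share a key."""
    normalised = " ".join(text.split()).lower()
    return hashlib.sha256(normalised.encode("utf-8")).hexdigest()


def enrich_with_cache(texts, cache, call_model):
    """Call the model only for texts whose key is not already cached."""
    results = []
    for text in texts:
        key = text_key(text)
        if key not in cache:
            cache[key] = call_model(text)
        results.append(cache[key])
    return results


calls = []  # records every (billable) model invocation


def fake_model(text):
    """Hypothetical classifier used in place of a real LLM endpoint."""
    calls.append(text)
    return "billing" if "invoice" in text.lower() else "other"


cache = {}
tickets = ["Invoice is wrong", "invoice  is wrong", "App crashes"]
labels = enrich_with_cache(tickets, cache, fake_model)
labels_again = enrich_with_cache(tickets, cache, fake_model)
print(labels, len(calls))
```

Three input rows trigger two model calls on the first pass and none on the second, which is the saving the cached Delta table is meant to provide.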

T.4.4 MLflow for LLM Lifecycle Management

MLflow 2.x introduced first-class support for LLM workflows: logging prompts and completions as structured artifacts, evaluating outputs with LLM-specific metrics, and packaging custom LLM chains as pyfunc models for deployment. On Databricks, MLflow is pre-installed and integrated with Unity Catalog, so experiments, runs, and registered models all participate in the same access control and lineage tracking as your data assets. For the broader context of experiment tracking, see Appendix R; for LLM evaluation methodology, see Chapter 29.

from databricks.vector_search.client import VectorSearchClient

vsc = VectorSearchClient()

# Create a vector search endpoint (compute that serves queries)
vsc.create_endpoint(
    name="llm-rag-endpoint",
    endpoint_type="STANDARD",
)

# Create a Delta Sync index: Databricks computes embeddings automatically
# using the specified embedding model endpoint, and re-syncs when the
# source Delta table changes
index = vsc.create_delta_sync_index(
    endpoint_name="llm-rag-endpoint",
    source_table_name="ml_catalog.rag_data.product_docs",
    index_name="ml_catalog.rag_data.product_docs_index",
    pipeline_type="TRIGGERED",           # synced on demand; CONTINUOUS gives near-real-time sync
    primary_key="chunk_id",              # unique per chunk row; doc_id repeats across chunks
    embedding_source_column="content",   # column containing text to embed
    embedding_model_endpoint_name="databricks-bge-large-en",
)
Code Fragment T.4.6: Using mlflow.evaluate() to assess a fine-tuned model on faithfulness, relevance, and toxicity. Results are logged to the MLflow run for comparison across model versions.

After evaluation, models are promoted in the Unity Catalog Model Registry by reassigning aliases such as Champion, via the MLflow client or the Databricks UI. Unity Catalog uses aliases in place of the fixed stages (None, Staging, Production) of the legacy workspace registry. Deployment to a Databricks Model Serving endpoint can then be triggered programmatically, completing the CI/CD loop described in Section T.5.
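
Before a version is promoted, its evaluation metrics are usually compared against the version currently serving traffic. The sketch below shows one possible promotion gate. The function `should_promote`, the tolerance, and the sample scores are hypothetical; the metric names follow those logged in the evaluation example in this section.

```python
def should_promote(
    candidate,
    champion,
    higher_is_better=("faithfulness_mean", "relevance_mean"),
    lower_is_better=("toxicity_ratio",),
    tolerance=0.0,
):
    """Return True only if the candidate does not regress on any tracked metric."""
    for name in higher_is_better:
        if candidate[name] < champion[name] - tolerance:
            return False
    for name in lower_is_better:
        if candidate[name] > champion[name] + tolerance:
            return False
    return True


# Hypothetical scores for the current champion and two candidate versions
champion = {"faithfulness_mean": 4.1, "relevance_mean": 4.0, "toxicity_ratio": 0.01}
better = {"faithfulness_mean": 4.3, "relevance_mean": 4.0, "toxicity_ratio": 0.0}
more_toxic = {"faithfulness_mean": 4.5, "relevance_mean": 4.2, "toxicity_ratio": 0.05}

print(should_promote(better, champion))
print(should_promote(more_toxic, champion))
```

Requiring no regression on every metric is a conservative policy; a team might instead weight the metrics or allow a small tolerance to absorb evaluation noise.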

# Query the vector search index at inference time
results = index.similarity_search(
    query_text="How do I configure SSO with Azure AD?",
    columns=["doc_id", "title", "content", "last_updated"],
    num_results=5,
    filters={"doc_type": "integration_guide"},  # optional metadata filter
)

for hit in results["result"]["data_array"]:
    doc_id, title, content, last_updated, score = hit
    print(f"[{score:.3f}] {title} ({doc_id})")
Code Fragment T.4.7: Promoting a model version to the Champion alias in Unity Catalog and updating a Model Serving endpoint to serve it. The alias decouples the endpoint configuration from the specific version number.
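
Positional unpacking of search hits breaks as soon as the requested column list changes. A small helper, sketched below, turns each hit into a dictionary instead. It assumes the response shape used in this section's query example (rows under `result.data_array`, with the similarity score appended last); `hits_to_dicts` and the sample response are hypothetical.

```python
def hits_to_dicts(response, columns):
    """Pair each returned row with its column names plus a trailing score."""
    names = list(columns) + ["score"]
    return [dict(zip(names, row)) for row in response["result"]["data_array"]]


# Hand-written sample mimicking the response shape; not real search output
sample_response = {
    "result": {
        "data_array": [
            ["doc-17", "Configuring SSO", "Enable SSO in the admin console.", "2024-05-01", 0.83],
            ["doc-42", "Identity providers", "Supported providers are listed here.", "2024-03-12", 0.71],
        ]
    }
}

columns = ["doc_id", "title", "content", "last_updated"]
hits = hits_to_dicts(sample_response, columns)
for hit in hits:
    print(f"[{hit['score']:.3f}] {hit['title']} ({hit['doc_id']})")
```

Passing the same `columns` list to both the search call and the helper keeps the two in step.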

T.4.5 Databricks Vector Search

Databricks Vector Search is a managed vector index that lives alongside Delta tables in the Lakehouse. Unlike standalone vector databases, it synchronises automatically with its source Delta table: when rows are added, updated, or deleted, the index updates accordingly without requiring a separate ingestion pipeline. Embeddings can be computed by a Foundation Model API embedding endpoint and stored in the index, or you can supply pre-computed embedding columns. This end-to-end integration reduces the operational surface for RAG applications significantly compared to maintaining a separate service such as Pinecone or Weaviate alongside Databricks. For comparison with standalone vector databases, see Chapter 20.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, monotonically_increasing_id, explode
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()

# UDF: split document text into overlapping chunks
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64):
    """Split text into chunks of ~chunk_size characters with overlap."""
    if not text:
        return []
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        chunks.append({"chunk_text": text[start:end], "chunk_index": len(chunks)})
        if end == len(text):
            break
        start += chunk_size - overlap
    return chunks

chunk_schema = ArrayType(StructType([
    StructField("chunk_text", StringType()),
    StructField("chunk_index", IntegerType()),
]))
chunk_udf = udf(chunk_text, chunk_schema)

# Explode raw documents into chunk rows
raw_docs = spark.table("ml_catalog.rag_data.raw_documents")

chunked = (
    raw_docs
    .withColumn("chunks", chunk_udf(col("content")))
    .select("doc_id", "title", "doc_type", "last_updated", explode("chunks").alias("chunk"))
    .select(
        "doc_id", "title", "doc_type", "last_updated",
        col("chunk.chunk_text").alias("content"),
        col("chunk.chunk_index").alias("chunk_index"),
    )
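    # Derive a stable key from doc_id and chunk_index; monotonically_increasing_id()
    # restarts on every run, so its values can collide across appends
    .selectExpr(
        "*",
        "concat_ws('-', CAST(doc_id AS STRING), CAST(chunk_index AS STRING)) AS chunk_id",
    )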
)

chunked.write.format("delta").mode("append").saveAsTable("ml_catalog.rag_data.product_docs")
Code Fragment T.4.8: PySpark ingestion pipeline that chunks raw documents using an overlapping window UDF and writes them to the Delta table backing the Vector Search index. After this job runs, the index picks up the new rows on its next sync.
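
The overlapping-window logic in the chunking UDF can be checked without Spark. The sketch below reproduces the same stepping rule in plain Python and compares it with a closed-form chunk count. The names `chunk_spans` and `expected_chunks` are introduced here for illustration, and the guard on `overlap` is an addition that the UDF above does not include.

```python
import math


def chunk_spans(n_chars, chunk_size=512, overlap=64):
    """Return (start, end) offsets using the same stepping rule as the UDF."""
    if not 0 <= overlap < chunk_size:
        # Without this guard, overlap >= chunk_size would never advance
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    spans = []
    start = 0
    while start < n_chars:
        end = min(start + chunk_size, n_chars)
        spans.append((start, end))
        if end == n_chars:
            break
        start += chunk_size - overlap
    return spans


def expected_chunks(n_chars, chunk_size=512, overlap=64):
    """Closed-form chunk count: ceil((n - overlap) / (chunk_size - overlap))."""
    if n_chars <= 0:
        return 0
    if n_chars <= chunk_size:
        return 1
    return math.ceil((n_chars - overlap) / (chunk_size - overlap))


spans = chunk_spans(1200)
print(spans)
print(expected_chunks(1200))
```

The closed form is useful for estimating how many rows, and therefore how many embedding calls, a corpus will produce before the job is run.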
import mlflow
import mlflow.pyfunc
from databricks.vector_search.client import VectorSearchClient
from openai import OpenAI

class DatabricksRAGChain(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper for a Databricks-native RAG chain."""

    def load_context(self, context):
        self.vsc = VectorSearchClient()
        self.index = self.vsc.get_index(
            endpoint_name="llm-rag-endpoint",
            index_name="ml_catalog.rag_data.product_docs_index",
        )
        self.llm = OpenAI(
            api_key=context.model_config["databricks_token"],
            base_url=context.model_config["databricks_host"] + "/serving-endpoints",
        )
        self.model_name = context.model_config.get(
            "generation_model", "databricks-dbrx-instruct"
        )

    def predict(self, context, model_input):
        questions = model_input["question"].tolist()
        answers = []

        for question in questions:
            # Retrieve top-5 relevant chunks
            results = self.index.similarity_search(
                query_text=question,
                columns=["title", "content"],
                num_results=5,
            )
            context_text = "\n\n".join(
                f"[{row[0]}]\n{row[1]}"
                for row in results["result"]["data_array"]
            )

            # Generate answer with retrieved context
            response = self.llm.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "Answer the question using only the provided context. "
                            "If the context does not contain the answer, say so."
                        ),
                    },
                    {
                        "role": "user",
                        "content": f"Context:\n{context_text}\n\nQuestion: {question}",
                    },
                ],
                max_tokens=512,
                temperature=0.1,
            )
            answers.append(response.choices[0].message.content)

        import pandas as pd
        return pd.DataFrame({"answer": answers})


# Log the chain to MLflow and register it
with mlflow.start_run(run_name="rag-chain-v1"):
    model_config = {
        "databricks_host": "https://your-workspace.azuredatabricks.net",
        "databricks_token": "{{secrets/rag-scope/databricks-token}}",
        "generation_model": "databricks-dbrx-instruct",
    }
    mlflow.pyfunc.log_model(
        artifact_path="rag_chain",
        python_model=DatabricksRAGChain(),
        model_config=model_config,
        registered_model_name="ml_catalog.rag_apps.product_docs_rag",
    )
Code Fragment T.4.9: A complete RAG chain packaged as an MLflow pyfunc model. The predict method queries the Vector Search index, assembles context from the top five chunks, and calls the generation LLM. Once logged and registered, the chain deploys to a Databricks Model Serving endpoint.
Practical Example

A team maintaining a product documentation portal can point a Delta Sync index at the product_docs Delta table written by their documentation pipeline. When writers publish new articles, the Spark job appends rows to the Delta table, and the index catches up on its next triggered sync (TRIGGERED mode) or within seconds (CONTINUOUS mode). No separate vector database ETL job is needed, and the index stays tied to the source of truth.

T.4.6 Building RAG Applications on Databricks

The Databricks Lakehouse provides every component needed for a production RAG system: Delta Lake for document storage, PySpark for chunking and preprocessing (see Section T.6), Foundation Model APIs for embedding and generation, Vector Search for retrieval, and MLflow for chain evaluation and deployment. This section assembles those components into a complete RAG pipeline, from raw document ingestion through to a deployed Model Serving endpoint. The RAG architecture itself is covered in depth in Chapter 20.

 Documents           Chunking             Embedding             Vector Index
┌──────────┐   ┌──────────────┐   ┌──────────────────┐   ┌──────────────────┐
│  Delta   │──▶│   PySpark    │──▶│  Foundation      │──▶│  Databricks      │
│  Lake    │   │   UDF chunk  │   │  Model API       │   │  Vector Search   │
│  (raw)   │   │   overlap    │   │  (BGE/E5)        │   │  (auto-synced)   │
└──────────┘   └──────────────┘   └──────────────────┘   └──────────┬───────┘
                                                                     │ retrieve
                                                          ┌──────────▼───────┐
                                                          │  Generation LLM  │
                Query ──────────────────────────────────▶│  (DBRX / Llama)  │
                                                          │  + MLflow chain  │
                                                          └──────────────────┘
        
Figure T.4.1: End-to-end RAG architecture on Databricks. All components except the user query live inside the Lakehouse, sharing Unity Catalog governance and MLflow lineage.

The ingestion pipeline (Code Fragment T.4.8 above) chunks documents with PySpark, writes chunks to a Delta table, and lets Vector Search handle embedding and indexing automatically.

The RAG chain itself (Code Fragment T.4.9 above) is a Python function wrapped as a pyfunc MLflow model. This makes it deployable to any MLflow-compatible endpoint, including Databricks Model Serving.

Key Insight

Wrapping the RAG chain as an MLflow pyfunc model provides more than just deployment convenience. It enables mlflow.evaluate() to run the full chain end-to-end against a labeled evaluation set, measuring retrieval quality (context recall, context precision) and generation quality (faithfulness, answer relevance) in one call. Evaluation with structured metrics is covered in Chapter 29.
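
To make the retrieval metrics concrete, the sketch below computes ID-based context precision and recall for a single query. These are simplified definitions chosen for illustration, not MLflow's implementation; the function name and the sample IDs are hypothetical.

```python
def context_precision_recall(retrieved_ids, relevant_ids):
    """Precision: share of retrieved chunks that are relevant.
    Recall: share of relevant chunks that were retrieved."""
    retrieved = list(dict.fromkeys(retrieved_ids))  # drop duplicates, keep order
    relevant = set(relevant_ids)
    hits = sum(1 for chunk_id in retrieved if chunk_id in relevant)
    precision = hits / len(retrieved) if retrieved else 0.0
    recall = hits / len(relevant) if relevant else 0.0
    return precision, recall


# Hypothetical chunk IDs: five retrieved chunks, three labeled as relevant
retrieved = ["a", "b", "c", "d", "e"]
relevant = ["a", "c", "x"]

precision, recall = context_precision_recall(retrieved, relevant)
print(f"precision={precision:.2f} recall={recall:.2f}")
```

Low precision means the prompt is padded with irrelevant chunks, while low recall means the answer may be missing from the context altogether; the two failures call for different fixes.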

Summary

Databricks provides a vertically integrated path from raw data to deployed AI application. Mosaic AI handles distributed training and managed fine-tuning, with MLflow capturing every experiment for reproducibility (see Appendix R). Foundation Model APIs give pay-per-token access to hosted open-weight models using the familiar OpenAI interface, with provisioned throughput available for latency-sensitive workloads. AI Functions bring LLM calls into SQL, making it practical to enrich data at ingestion time inside Delta Live Tables pipelines (see Section T.2). MLflow Evaluate provides LLM-specific metrics for comparing model versions before promotion through the Unity Catalog Model Registry. Vector Search adds a managed retrieval layer that auto-syncs with source Delta tables, removing the need for a separate vector database in most RAG architectures. Finally, all these components compose naturally into end-to-end RAG applications that are deployed, versioned, and evaluated entirely within the Lakehouse, as detailed in Chapter 20 and the evaluation framework of Chapter 29.