Building Conversational AI with LLMs and Agents
Appendix T: Distributed ML: PySpark, Databricks, and Ray

PySpark for LLM Data Pipelines

Big Picture

Training and fine-tuning LLMs requires processing corpora that range from tens of gigabytes to multiple terabytes. PySpark provides the distributed compute layer that makes this tractable: it parallelizes text cleaning, deduplication, tokenization, and embedding generation across hundreds of cores, and it integrates naturally with the Delta Lake storage layer covered in Section T.2 and the Databricks platform covered in Section T.1. This section walks through each stage of a PySpark-based LLM data pipeline, from reading raw corpora to writing training-ready Parquet files and populating vector databases.

T.1.1 PySpark Fundamentals for Text Data

Before processing any text, configure your SparkSession for the memory profiles that LLM data pipelines demand. Text data is wide (long strings), so default Spark settings for shuffle partitions and executor memory are often too conservative.

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, FloatType, IntegerType

# Configure SparkSession for LLM data workloads
spark = (
    SparkSession.builder
    .appName("llm-data-pipeline")
    # Increase shuffle partitions for large text corpora
    .config("spark.sql.shuffle.partitions", "800")
    # Push filter predicates down to the Parquet reader so non-matching row groups are skipped (on by default)
    .config("spark.sql.parquet.filterPushdown", "true")
    # Use Kryo for faster serialization of JVM objects during shuffles (Python objects are still pickled)
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    # Driver and executor memory: tune to your cluster
    .config("spark.driver.memory", "16g")
    .config("spark.executor.memory", "32g")
    .config("spark.executor.memoryOverhead", "4g")
    # Adaptive Query Execution (Spark 3+): auto-coalesces small partitions
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

spark.sparkContext.setLogLevel("WARN")

# Read from Delta Lake (recommended for managed pipelines)
df_delta = spark.read.format("delta").load("s3://my-bucket/corpora/c4-en")

# Read from Parquet (e.g., downloaded Common Crawl snapshots)
df_parquet = spark.read.parquet("s3://my-bucket/raw/cc-2024-10/")

# Read from newline-delimited JSON
df_json = (
    spark.read
    .option("multiline", "false")   # one JSON object per line
    .json("s3://my-bucket/raw/dolma-v1.7/")
)

# Read from CSV with header (e.g., OpenWebText metadata)
df_csv = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("s3://my-bucket/raw/openwebtext-meta/")
)

# Quick schema inspection
df_delta.printSchema()
print(f"Row count: {df_delta.count():,}")

# Filter to English, require minimum text length and quality score
df_filtered = (
    df_delta
    .filter(F.col("language") == "en")
    .filter(F.col("language_score") >= 0.85)
    .filter(F.length("text") >= 200)
    .filter(F.length("text") <= 100_000)
    .filter(F.col("quality_score") >= 0.5)
)

# Register a UDF for language detection on raw text (when no pre-computed label)
from langdetect import detect, LangDetectException

def detect_lang(text: str) -> str:
    if not text or len(text) < 20:
        return "unknown"
    try:
        return detect(text)
    except LangDetectException:
        return "unknown"

detect_lang_udf = F.udf(detect_lang, StringType())

df_with_lang = df_json.withColumn("detected_lang", detect_lang_udf(F.col("text")))
df_en = df_with_lang.filter(F.col("detected_lang") == "en")
Code Fragment T.1.1: End-to-end corpus ingestion: SparkSession configuration, reading multiple formats (Delta, Parquet, JSON, CSV), and language/quality filtering. Adaptive Query Execution auto-coalesces small partitions when corpus sizes vary widely.

PySpark can read all common corpus formats. Delta Lake is preferred for its ACID guarantees and time-travel capabilities (see Section T.2), but raw Parquet and JSON are also common for externally sourced datasets.

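Basic DataFrame operations for text quality gating run as distributed SQL-style transformations. Language filtering, length thresholding, and quality score cuts are all expressible with F.col predicates. When reading Parquet, simple comparisons on stored columns (such as language == "en" or quality_score >= 0.5) can be pushed down to the storage layer; predicates on computed values such as F.length("text") are evaluated after the rows are read.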

from pyspark.ml.feature import MinHashLSH, HashingTF, Tokenizer
from pyspark.sql import functions as F

# --- Exact deduplication via SHA-256 ---
# F.sha2 is a built-in, so hashing runs in the JVM without Python UDF overhead
df_hashed = df_filtered.withColumn("doc_hash", F.sha2(F.col("text"), 256))

# Keep only the first occurrence of each hash
from pyspark.sql.window import Window

window = Window.partitionBy("doc_hash").orderBy("url")
df_exact_deduped = (
    df_hashed
    .withColumn("rank", F.row_number().over(window))
    .filter(F.col("rank") == 1)
    .drop("rank")
)

print(f"After exact dedup: {df_exact_deduped.count():,} docs")
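
# --- Near-duplicate removal via MinHash LSH ---
# Tokenize into lowercase words; MinHash then compares documents as word sets.
# For true word shingles, add pyspark.ml.feature.NGram before HashingTF.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
df_words = tokenizer.transform(df_exact_deduped)

# Hash words into a sparse feature vector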
hashing_tf = HashingTF(inputCol="words", outputCol="features", numFeatures=262144)
df_features = hashing_tf.transform(df_words)

# Build MinHash LSH model
mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
model = mh.fit(df_features)

# Find near-duplicate pairs (Jaccard similarity >= 0.8)
df_pairs = model.approxSimilarityJoin(
    df_features, df_features,
    threshold=0.2,    # distance threshold = 1 - 0.8 Jaccard
    distCol="distance"
)

# Extract IDs of duplicates to remove (keep lowest ID in each cluster)
duplicates_to_remove = (
    df_pairs
    .filter(F.col("datasetA.id") != F.col("datasetB.id"))
    .select(F.greatest("datasetA.id", "datasetB.id").alias("remove_id"))
    .distinct()
)

df_near_deduped = df_features.join(
    duplicates_to_remove,
    df_features["id"] == duplicates_to_remove["remove_id"],
    how="left_anti"
)
Code Fragment T.1.2: Two-stage deduplication: exact removal via SHA-256 hashing with window functions, then near-duplicate removal via MinHash LSH. The distance threshold of 0.2 corresponds to Jaccard similarity above 0.8.
Note

Row-level Python UDFs incur Python-JVM serialization overhead for every row. For high-throughput pipelines, prefer pandas_udf (covered in T.1.3 and T.1.4) which processes batches of rows using Arrow, reducing overhead by 10-50x.

T.1.2 Large-Scale Text Preprocessing

Raw web corpora contain near-duplicate documents, malformed HTML, non-standard Unicode, and personally identifiable information (PII). Removing these before tokenization is essential for training data quality and legal compliance.

Deduplication

Exact deduplication uses a hash of the document content. Near-duplicate removal requires MinHash Locality Sensitive Hashing (LSH), which approximates Jaccard similarity between documents at scale.
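
To make the approximation concrete, here is a minimal pure-Python sketch of MinHash over word shingles, independent of Spark. The helper names (shingles, minhash_signature, estimated_jaccard), the 3-word shingle size, and the use of salted BLAKE2b as the hash family are illustrative assumptions, not a description of what MinHashLSH does internally. The fraction of signature positions on which two documents agree serves as an estimate of their Jaccard similarity.

```python
import hashlib


def shingles(text: str, k: int = 3) -> set:
    """Return the set of k-word shingles in a document (k=3 is an assumed default)."""
    words = text.lower().split()
    if len(words) < k:
        return {" ".join(words)} if words else set()
    return {" ".join(words[i:i + k]) for i in range(len(words) - k + 1)}


def jaccard(a: set, b: set) -> float:
    """Exact Jaccard similarity: size of intersection over size of union."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def _seeded_hash(seed: int, item: str) -> int:
    # One hash function per seed, built from salted BLAKE2b (an assumption).
    digest = hashlib.blake2b(
        item.encode("utf-8"), digest_size=8, salt=seed.to_bytes(8, "little")
    ).digest()
    return int.from_bytes(digest, "little")


def minhash_signature(items: set, num_hashes: int = 256) -> list:
    """Keep the minimum hash value of the set under each hash function."""
    if not items:
        raise ValueError("MinHash is undefined for an empty set")
    return [min(_seeded_hash(seed, s) for s in items) for seed in range(num_hashes)]


def estimated_jaccard(sig_a: list, sig_b: list) -> float:
    """Fraction of signature positions where the two documents agree."""
    return sum(x == y for x, y in zip(sig_a, sig_b)) / len(sig_a)


doc_a = "the quick brown fox jumps over the lazy dog near the river bank"
doc_b = "the quick brown fox jumps over the lazy dog near the river"
set_a, set_b = shingles(doc_a), shingles(doc_b)
exact = jaccard(set_a, set_b)
approx = estimated_jaccard(minhash_signature(set_a), minhash_signature(set_b))
print(f"exact={exact:.3f} approx={approx:.3f}")
```

More hash functions tighten the estimate at the cost of a longer signature. LSH then treats documents whose signatures collide in at least one hash table as candidate pairs, so only those pairs are compared.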

import pandas as pd
from pyspark.sql import functions as F

SEQ_LEN = 4096
# <|end_of_text|> for Llama 3.x tokenizers (Llama 2 uses 2); prefer tokenizer.eos_token_id
EOS_ID = 128001

def pack_sequences(pdf: pd.DataFrame) -> pd.DataFrame:
    """Greedily pack variable-length token sequences into fixed-length chunks."""
    packed = []
    current_chunk = []

    for token_ids in pdf["token_ids"]:
        # Arrays arrive as NumPy arrays, so test length rather than truthiness
        if token_ids is None or len(token_ids) == 0:
            continue
        # Append EOS after each document
        doc_tokens = [int(t) for t in token_ids] + [EOS_ID]

        # Split the document across windows when it exceeds the remaining space
        while doc_tokens:
            space = SEQ_LEN - len(current_chunk)
            current_chunk.extend(doc_tokens[:space])
            doc_tokens = doc_tokens[space:]

            if len(current_chunk) == SEQ_LEN:
                packed.append(current_chunk)
                current_chunk = []

    # Discard the final partial chunk (do not pad)
    return pd.DataFrame({"packed_sequences": packed})

# A Series-to-Series pandas_udf cannot change the row count, so pack each
# group with applyInPandas, which emits one output row per packed sequence
df_packed = (
    df_tokenized
    .withColumn("partition_id", F.spark_partition_id())
    .groupBy("partition_id")
    .applyInPandas(pack_sequences, schema="packed_sequences array<int>")
)
Code Fragment T.1.3: Greedy sequence packing. The pack_sequences UDF concatenates variable-length documents with EOS separators into fixed-length windows of SEQ_LEN tokens, improving GPU utilization by 15 to 30% compared to per-document truncation with padding.

Text Cleaning and PII Redaction

import re
import unicodedata
from pyspark.sql.functions import pandas_udf
import pandas as pd

# Regex patterns for PII and noise
EMAIL_RE = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")
# (?<!\w) instead of a leading \b so an opening "+" or "(" is part of the match
PHONE_RE = re.compile(r"(?<!\w)(\+?1[\s.-]?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b")
HTML_TAG_RE = re.compile(r"<[^>]+>")
MULTI_SPACE_RE = re.compile(r"\s+")

@pandas_udf(StringType())
def clean_text(texts: pd.Series) -> pd.Series:
    def _clean(text: str) -> str:
        if not isinstance(text, str):
            return ""
        # Strip HTML tags
        text = HTML_TAG_RE.sub(" ", text)
        # Normalize Unicode to NFC form
        text = unicodedata.normalize("NFC", text)
        # Redact PII
        text = EMAIL_RE.sub("[EMAIL]", text)
        text = PHONE_RE.sub("[PHONE]", text)
        # Collapse whitespace
        text = MULTI_SPACE_RE.sub(" ", text).strip()
        return text

    return texts.apply(_clean)

df_cleaned = df_near_deduped.withColumn("text", clean_text(F.col("text")))
Code Fragment T.1.4: Vectorized text cleaning via pandas_udf. Processing a pandas Series per Arrow batch is significantly faster than row-level UDFs for regex-heavy operations.
from transformers import AutoTokenizer
from pyspark.sql.types import ArrayType, IntegerType
import pandas as pd

MODEL_NAME = "meta-llama/Llama-3.1-8B"
TOKENIZER_BROADCAST = None  # loaded lazily inside each Python worker; not a Spark broadcast variable

@pandas_udf(ArrayType(IntegerType()))
def tokenize_text(texts: pd.Series) -> pd.Series:
    # Load tokenizer once per Python worker process (cached after first call)
    global TOKENIZER_BROADCAST
    if TOKENIZER_BROADCAST is None:
        TOKENIZER_BROADCAST = AutoTokenizer.from_pretrained(
            MODEL_NAME, use_fast=True
        )
    tok = TOKENIZER_BROADCAST

    # Tokenize the whole batch at once (fast tokenizers support batching)
    results = tok(
        texts.tolist(),
        truncation=False,
        padding=False,
        add_special_tokens=False,
    )
    return pd.Series(results["input_ids"])

df_tokenized = df_cleaned.withColumn("token_ids", tokenize_text(F.col("text")))
Code Fragment T.1.5: Batch tokenization via pandas_udf. The tokenizer is loaded once per Python worker process using a global variable. Setting use_fast=True selects the Rust-backed tokenizer, which is 5 to 10x faster than the Python implementation.
Practical Example: C4-Style Preprocessing Pipeline

The C4 dataset (used to train T5 and many subsequent models) applies a sequence of quality filters that are straightforward to replicate in PySpark.

Each filter is a simple F.col predicate or a single-pass UDF and adds negligible compute cost relative to the I/O required to read the corpus.
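
As a rough illustration, the sketch below applies C4-style heuristics to a single page in plain Python: keep only lines that end in terminal punctuation and have enough words, drop lines mentioning JavaScript, and drop pages that contain "lorem ipsum", a curly brace, or too few sentences. The function name c4_style_clean, the regex-based sentence count, and the default thresholds are assumptions made for this example; check them against the published C4 description before relying on them. The same function could be wrapped in a pandas_udf to run per batch.

```python
import re
from typing import Optional

TERMINAL_PUNCT = (".", "!", "?", '"')
SENTENCE_END_RE = re.compile(r"[.!?]+(?:\s|$)")


def c4_style_clean(
    page: str, min_words_per_line: int = 5, min_sentences: int = 3
) -> Optional[str]:
    """Return the filtered page text, or None if the whole page is rejected."""
    # Page-level rejections: placeholder text and likely source code.
    if "lorem ipsum" in page.lower() or "{" in page:
        return None

    kept = []
    for line in page.splitlines():
        line = line.strip()
        if not line.endswith(TERMINAL_PUNCT):
            continue  # menus, headings, and truncated fragments
        if len(line.split()) < min_words_per_line:
            continue
        if "javascript" in line.lower():
            continue  # "please enable JavaScript" notices
        kept.append(line)

    text = "\n".join(kept)
    # Approximate the sentence count by counting sentence-ending punctuation.
    if len(SENTENCE_END_RE.findall(text)) < min_sentences:
        return None
    return text


page = (
    "Home | About | Contact\n"
    "This is the first full sentence here.\n"
    "Please enable JavaScript to view this page.\n"
    "A second sentence follows with enough words.\n"
    "And a third one closes the short page!"
)
cleaned = c4_style_clean(page)
print(cleaned)
```

The navigation line and the JavaScript notice are dropped, and the three remaining sentences are enough for the page to pass the sentence-count gate.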

T.1.3 Tokenization and Dataset Preparation

Once text is cleaned and deduplicated, the next step is converting it into token sequences that training frameworks can consume. HuggingFace tokenizers are the de facto standard; applying them at scale requires wrapping them in a pandas_udf so the tokenizer is loaded once per Python worker process and applied to a whole batch at a time, rather than once per row.

# Write packed sequences to Parquet with tuned row group sizes
OUTPUT_PATH = "s3://my-bucket/training/llama-sft-packed-v3"

(
    df_packed
    .write
    .mode("overwrite")
    .option("parquet.block.size", 128 * 1024 * 1024)   # 128 MB row groups
    .option("parquet.page.size",  1 * 1024 * 1024)     # 1 MB pages
    .parquet(OUTPUT_PATH)
)

# --- In your training script, load directly as a HuggingFace Dataset ---
from datasets import load_dataset

dataset = load_dataset(
    "parquet",
    data_files={"train": f"{OUTPUT_PATH}/*.parquet"},
    split="train",
    num_proc=8,
)
print(dataset)
# Dataset({features: ['packed_sequences'], num_rows: 4200000})
Code Fragment T.1.6: Writing packed sequences to Parquet with tuned row group sizes, then loading them as a HuggingFace Dataset via the parquet loader. The 128 MB row group setting balances metadata overhead against read parallelism.

Training efficiency requires packing multiple short documents into fixed-length sequences to avoid wasting context window capacity on padding tokens. The following UDF implements a greedy bin-packing approach per partition.
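
The packing rule itself can be exercised without Spark. The sketch below is a plain-Python version of the greedy approach described above, with a hypothetical function name (pack_token_lists) and toy values for the sequence length and EOS ID so the result is easy to follow by hand. Unlike a production job, it also returns the leftover partial chunk so you can see what would be discarded.

```python
from typing import List, Optional, Sequence, Tuple


def pack_token_lists(
    docs: Sequence[Optional[Sequence[int]]], seq_len: int, eos_id: int
) -> Tuple[List[List[int]], List[int]]:
    """Concatenate documents with EOS separators and cut into seq_len windows.

    Returns (full_chunks, leftover_partial_chunk).
    """
    packed: List[List[int]] = []
    current: List[int] = []

    for doc in docs:
        if doc is None or len(doc) == 0:
            continue  # skip empty documents entirely
        tokens = list(doc) + [eos_id]

        # A long document may span several windows.
        while tokens:
            space = seq_len - len(current)
            current.extend(tokens[:space])
            tokens = tokens[space:]
            if len(current) == seq_len:
                packed.append(current)
                current = []

    return packed, current


docs = [[1, 2, 3], [4, 5], [], [6, 7, 8, 9], [10]]
chunks, leftover = pack_token_lists(docs, seq_len=4, eos_id=0)
print(chunks)    # [[1, 2, 3, 0], [4, 5, 0, 6], [7, 8, 9, 0]]
print(leftover)  # [10, 0]
```

Every emitted chunk is exactly seq_len tokens long, so no padding is needed. The trade-off is that a document can straddle two chunks, as the fourth document does here.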

import torch
import numpy as np
import pandas as pd
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql.functions import pandas_udf, col
from sentence_transformers import SentenceTransformer

EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
EMBED_DIM = 1024
BATCH_SIZE = 256   # tune to GPU memory

# Module-level cache: populated once per Python worker process (not the executor JVM)
_model_cache: dict = {}

@pandas_udf(ArrayType(FloatType()))
def embed_texts(texts: pd.Series) -> pd.Series:
    global _model_cache
    if EMBEDDING_MODEL not in _model_cache:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _model_cache[EMBEDDING_MODEL] = SentenceTransformer(
            EMBEDDING_MODEL, device=device
        )
    model = _model_cache[EMBEDDING_MODEL]

    # Process in mini-batches to avoid OOM on large partitions
    all_embeddings = []
    text_list = texts.tolist()
    for i in range(0, len(text_list), BATCH_SIZE):
        batch = text_list[i : i + BATCH_SIZE]
        embs = model.encode(
            batch,
            batch_size=len(batch),
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        all_embeddings.extend(embs.tolist())

    return pd.Series(all_embeddings)

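# Repartition to spread work across the cluster. defaultParallelism is the total
# core count; pass your GPU count instead if each GPU should own one partition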
df_embedded = (
    df_cleaned
    .select("id", "text", "url")
    .repartition(spark.sparkContext.defaultParallelism)
    .withColumn("embedding", embed_texts(col("text")))
)
Code Fragment T.1.7: GPU embedding generation at scale. The module-level _model_cache dict ensures the model is loaded once per Python worker process, regardless of how many tasks that worker runs. Repartitioning to match GPU count minimizes model-load overhead.

Converting to HuggingFace Datasets

The most efficient conversion path from Spark to a HuggingFace Dataset uses Apache Arrow, which both Spark and HuggingFace natively support. Write Parquet from Spark, then load with datasets.load_dataset.

Note

Row group size directly affects streaming throughput. Groups that are too small increase metadata overhead; groups that are too large reduce parallelism when reading. The 128 MB default is a good starting point for single-GPU training. See Section T.2 for Delta Lake-specific Parquet optimization settings.
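
A back-of-the-envelope calculation helps when choosing a row group size. The sketch below estimates how many packed sequences fit in one row group and how many row groups a dataset produces. It assumes 4-byte token IDs and ignores Parquet encoding, compression, and nesting overhead, so treat the result as an order-of-magnitude guide only. The helper names are hypothetical, and the 4,200,000-row figure is borrowed from the example output in Code Fragment T.1.6.

```python
import math


def sequences_per_row_group(
    row_group_bytes: int, seq_len: int, bytes_per_token: int = 4
) -> int:
    """Approximate number of fixed-length sequences per row group (uncompressed)."""
    return row_group_bytes // (seq_len * bytes_per_token)


def row_groups_needed(num_sequences: int, per_group: int) -> int:
    """Number of row groups required to hold all sequences."""
    return math.ceil(num_sequences / per_group)


per_group = sequences_per_row_group(128 * 1024 * 1024, seq_len=4096)
groups = row_groups_needed(4_200_000, per_group)
print(per_group, groups)  # 8192 513
```

Under these assumptions, a reader that parallelizes over row groups would have roughly 500 independent units of work for this dataset.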

T.1.4 Embedding Generation at Scale

Generating embeddings for millions of documents is one of the most GPU-intensive offline tasks in LLM infrastructure. The key challenge is loading the embedding model exactly once per Python worker process on each executor (not once per row or even once per batch), and using Arrow-based pandas_udf to send large batches to the GPU efficiently. This pattern connects directly to the vector search infrastructure covered in Chapter 19.
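
The load-once behaviour is easy to verify in isolation. The following sketch uses functools.lru_cache as the per-process cache and a hypothetical FakeEmbedder in place of a real model, so it runs without a GPU or Spark; the names get_model, embed_batch, and load_log are illustrative only. It records each load so you can confirm that repeated batch calls reuse the same model instance.

```python
from functools import lru_cache
from typing import List

load_log: List[str] = []  # one entry per expensive model load


class FakeEmbedder:
    """Hypothetical stand-in for a real embedding model."""

    def __init__(self, name: str) -> None:
        self.name = name

    def encode(self, texts: List[str]) -> List[List[float]]:
        # Toy one-dimensional "embedding": the text length.
        return [[float(len(t))] for t in texts]


@lru_cache(maxsize=None)
def get_model(name: str) -> FakeEmbedder:
    """Load the model on first use; later calls in this process hit the cache."""
    load_log.append(name)
    return FakeEmbedder(name)


def embed_batch(texts: List[str], batch_size: int = 2) -> List[List[float]]:
    model = get_model("fake-embedder")
    out: List[List[float]] = []
    # Mini-batches bound peak memory regardless of how large the input is.
    for i in range(0, len(texts), batch_size):
        out.extend(model.encode(texts[i:i + batch_size]))
    return out


first = embed_batch(["a", "bb", "ccc"])
second = embed_batch(["dddd"])
print(len(load_log), first, second)  # 1 [[1.0], [2.0], [3.0]] [[4.0]]
```

In a Spark job each Python worker process holds its own copy of this cache, so the number of model loads scales with the number of worker processes rather than with the number of rows or batches.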

Writing Embeddings and Ingesting into Vector Databases

# Write embeddings to Parquet for offline use
EMBED_PATH = "s3://my-bucket/embeddings/bge-large-v1"

(
    df_embedded
    .write
    .mode("overwrite")
    .parquet(EMBED_PATH)
)

# --- Batch upsert to Pinecone ---
import pinecone
from pyspark.sql.functions import struct

PINECONE_API_KEY = "..."
INDEX_NAME = "docs-bge-large"

def upsert_to_pinecone(rows):
    """Called once per Spark partition to batch-upsert embeddings."""
    pc = pinecone.Pinecone(api_key=PINECONE_API_KEY)
    index = pc.Index(INDEX_NAME)

    batch = []
    for row in rows:
        batch.append({
            "id": str(row["id"]),  # Pinecone vector IDs must be strings
            "values": row["embedding"],
            "metadata": {"url": row["url"]},
        })
        if len(batch) == 200:
            index.upsert(vectors=batch)
            batch = []
    if batch:
        index.upsert(vectors=batch)

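# foreachPartition creates the Pinecone client once per partition.
# Read the saved Parquet back so the embeddings are not recomputed on the GPU.
spark.read.parquet(EMBED_PATH).foreachPartition(upsert_to_pinecone)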
Code Fragment T.1.8: Writing embeddings to Parquet for offline use and batch upserting to Pinecone via foreachPartition. The same partition-level pattern works for Weaviate (using weaviate.batch) and Milvus (using pymilvus.Collection.insert).
Key Insight

Use foreachPartition rather than foreach when writing to external systems like vector databases. foreach calls your function once per row, which causes one network round-trip per document. foreachPartition calls your function once per partition with an iterator, enabling batched writes and connection reuse that reduce total ingestion time by orders of magnitude.
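
The difference in request count can be seen with a small stand-alone sketch. CountingClient below is a hypothetical stand-in for a vector database client that only counts calls, and batched is an illustrative helper; neither is part of Spark or any vendor SDK. The two writer functions mirror what foreach and foreachPartition make possible.

```python
from itertools import islice


def batched(rows, size):
    """Yield lists of up to `size` items from any iterator."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            return
        yield chunk


class CountingClient:
    """Hypothetical stand-in for a vector database client; counts requests."""

    def __init__(self):
        self.requests = 0
        self.vectors = 0

    def upsert(self, vectors):
        self.requests += 1
        self.vectors += len(vectors)


def write_per_row(rows, client):
    # What foreach amounts to: one request per document.
    for row in rows:
        client.upsert([row])


def write_per_partition(rows, client, batch_size=200):
    # What foreachPartition enables: one request per batch of documents.
    for chunk in batched(rows, batch_size):
        client.upsert(chunk)


partition = [{"id": str(i), "values": [0.0, 1.0]} for i in range(1050)]
per_row, per_batch = CountingClient(), CountingClient()
write_per_row(iter(partition), per_row)
write_per_partition(iter(partition), per_batch)
print(per_row.requests, per_batch.requests)  # 1050 6
```

Both writers store the same 1,050 vectors, but the batched version needs 6 requests instead of 1,050.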

T.1.5 Monitoring and Optimizing Spark Jobs

Even well-designed PySpark pipelines require tuning when run on real corpora. The Spark UI is the primary diagnostic tool; understanding its output is essential for identifying and fixing the bottlenecks most common in LLM data pipelines.

Using the Spark UI

The Spark UI is available at http://<driver-host>:4040 during a running job. On Databricks, it is accessible via the cluster detail page. The most useful views for LLM pipeline debugging are:

Common Bottlenecks in LLM Data Pipelines

# --- Diagnosing and fixing partition skew ---
# Check partition sizes after a wide transformation
from pyspark.sql.functions import spark_partition_id

partition_sizes = (
    df_cleaned
    .withColumn("pid", spark_partition_id())
    .groupBy("pid")
    .count()
    .orderBy("count", ascending=False)
)
partition_sizes.show(20)

# If max partition >> median, salt the skewed key
# Example: language="en" dominates. Salt with a random integer.
import pyspark.sql.functions as F

SALT_BUCKETS = 50
df_salted = df_cleaned.withColumn(
    "lang_salted",
    F.concat(F.col("language"), F.lit("_"), (F.rand() * SALT_BUCKETS).cast("int"))
)

# --- Broadcast join for small lookup tables ---
# Use when one side of a join is small (e.g., a domain blocklist)
from pyspark.sql.functions import broadcast

df_blocklist = spark.read.parquet("s3://my-bucket/blocklist/")  # small table

df_clean_domains = df_cleaned.join(
    broadcast(df_blocklist),
    on="domain",
    how="left_anti"  # exclude rows matching the blocklist
)

# --- Tune partition count for large shuffles ---
# Rule of thumb: target 100-200 MB per partition after shuffle
# For a 2 TB corpus with 200 MB target:
spark.conf.set("spark.sql.shuffle.partitions", str(2000 * 1024 // 200))
Code Fragment T.1.9: Diagnosing partition skew, applying salting, using broadcast joins for small lookup tables, and setting shuffle partition count. These three techniques resolve the majority of Spark performance issues in text pipelines.
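
To see why salting helps, the sketch below simulates hash partitioning in plain Python. It assigns keys to partitions with a stable MD5-based hash (an assumption standing in for Spark's own partitioner) and compares the largest partition before and after a random salt is appended to the key. The 90/5/5 language mix, the 16 partitions, and the helper names are illustrative values, not measurements.

```python
import hashlib
import random
from collections import Counter


def stable_partition(key: str, num_partitions: int) -> int:
    """Map a key to a partition with a hash that is stable across runs."""
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


def partition_loads(keys, num_partitions: int) -> Counter:
    """Count how many rows land in each partition."""
    return Counter(stable_partition(k, num_partitions) for k in keys)


NUM_PARTITIONS = 16
SALT_BUCKETS = 50
rng = random.Random(0)  # seeded so the simulation is repeatable

# A skewed key column: "en" dominates, as in the fragment above.
languages = ["en"] * 9000 + ["de"] * 500 + ["fr"] * 500
salted = [f"{lang}_{rng.randrange(SALT_BUCKETS)}" for lang in languages]

unsalted_loads = partition_loads(languages, NUM_PARTITIONS)
salted_loads = partition_loads(salted, NUM_PARTITIONS)
unsalted_max = max(unsalted_loads.values())
salted_max = max(salted_loads.values())
print(f"largest partition: unsalted={unsalted_max} salted={salted_max}")
```

Without salting, every "en" row hashes to the same partition. With salting, the dominant key is split into 50 sub-keys that spread across partitions. Any aggregation on the original key then needs a second step that combines the per-salt partial results.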
Warning

Python UDFs serialize each row as a Python object, cross the JVM-Python boundary, execute the Python function, and serialize the result back. On a 1 TB corpus this overhead can be the dominant cost in your pipeline. Always profile with the Spark UI before adding UDFs, and prefer built-in Spark SQL functions (F.regexp_replace, F.length, etc.) when they cover your use case. When logic genuinely cannot be expressed with built-ins, prefer a vectorized pandas_udf, and fall back to a row-level Python UDF only when the logic cannot operate on batches.

Cost Optimization

# --- Databricks cluster autoscaling configuration ---
# Shown as a Python dict; submit it as JSON via the cluster creation API or set the same values in the Databricks UI
CLUSTER_CONFIG = {
    "cluster_name": "llm-data-pipeline",
    "spark_version": "15.4.x-scala2.12",
    "node_type_id": "i3.2xlarge",
    "autoscale": {
        "min_workers": 4,
        "max_workers": 40,
    },
    "aws_attributes": {
        "availability": "SPOT_WITH_FALLBACK",
        "spot_bid_price_percent": 100,
        "first_on_demand": 2,   # keep 2 on-demand nodes as stable driver/coordinator
    },
    "spark_conf": {
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
    },
}

# --- Checkpoint long pipelines to avoid recomputation on failure ---
spark.sparkContext.setCheckpointDir("s3://my-bucket/checkpoints/")

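# Checkpoint after the expensive deduplication stage.
# checkpoint() returns a new DataFrame, so reassign it to use the truncated lineage.
df_near_deduped = df_near_deduped.checkpoint()   # materializes to S3, breaks lineage graph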

# --- Cache DataFrames used more than once ---
df_tokenized.cache()
df_tokenized.count()    # trigger materialization immediately
Code Fragment T.1.10: Cost optimization via spot instances with on-demand fallback, adaptive autoscaling, checkpointing to break long lineage graphs, and explicit caching for reused DataFrames.
Tip

For pipelines that process a corpus incrementally (new data arriving daily), use Delta Lake's MERGE INTO operation rather than reprocessing the entire corpus each run. Combined with Delta's change data feed, you can identify only new or updated documents and pass them through the expensive deduplication and tokenization stages. This is covered in detail in Section T.2.

Summary

PySpark provides the distributed compute foundation for every stage of an LLM data pipeline. The SparkSession configuration choices in T.1.1 set the foundation for throughput and memory stability. Deduplication via exact hashing and MinHash LSH in T.1.2 is the most impactful single quality intervention for pretraining data. The pandas_udf pattern for tokenization and embedding generation in T.1.3 and T.1.4 bridges PySpark with the HuggingFace and SentenceTransformer ecosystems without sacrificing throughput. Finally, the Spark UI diagnostics and tuning patterns in T.1.5 give you the tools to identify and eliminate the bottlenecks that emerge when pipelines move from prototype to production scale.

These techniques integrate directly with the Databricks platform (Section T.1), Delta Lake storage and ACID semantics (Section T.2), and the HuggingFace Datasets library for downstream training (Appendix K). For the vector search applications that consume the embeddings produced here, see Chapter 19.