Training and fine-tuning LLMs requires processing corpora that range from tens of gigabytes to multiple terabytes. PySpark provides the distributed compute layer that makes this tractable: it parallelizes text cleaning, deduplication, tokenization, and embedding generation across hundreds of cores, and it integrates naturally with the Delta Lake storage layer covered in Section T.2 and the Databricks platform covered in Section T.1. This section walks through each stage of a PySpark-based LLM data pipeline, from reading raw corpora to writing training-ready Parquet files and populating vector databases.
T.1.1 PySpark Fundamentals for Text Data
Before processing any text, configure your SparkSession for the memory profiles that
LLM data pipelines demand. Text data is wide (long strings), so default Spark settings for shuffle
partitions and executor memory are often too conservative.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, FloatType, IntegerType

# Configure SparkSession for LLM data workloads
spark = (
    SparkSession.builder
    .appName("llm-data-pipeline")
    # Increase shuffle partitions for large text corpora
    .config("spark.sql.shuffle.partitions", "800")
    # Push filters down into the Parquet reader so non-matching row groups are skipped
    .config("spark.sql.parquet.filterPushdown", "true")
    # Use Kryo for faster JVM-side serialization
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    # Driver and executor memory: tune to your cluster
    .config("spark.driver.memory", "16g")
    .config("spark.executor.memory", "32g")
    .config("spark.executor.memoryOverhead", "4g")
    # Adaptive Query Execution (Spark 3+): auto-coalesces small partitions
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")
# Read from Delta Lake (recommended for managed pipelines)
df_delta = spark.read.format("delta").load("s3://my-bucket/corpora/c4-en")

# Read from Parquet (e.g., downloaded Common Crawl snapshots)
df_parquet = spark.read.parquet("s3://my-bucket/raw/cc-2024-10/")

# Read from newline-delimited JSON
df_json = (
    spark.read
    .option("multiline", "false")  # one JSON object per line
    .json("s3://my-bucket/raw/dolma-v1.7/")
)

# Read from CSV with header (e.g., OpenWebText metadata)
df_csv = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("s3://my-bucket/raw/openwebtext-meta/")
)
# Quick schema inspection
df_delta.printSchema()
print(f"Row count: {df_delta.count():,}")

# Filter to English, require minimum text length and quality score
df_filtered = (
    df_delta
    .filter(F.col("language") == "en")
    .filter(F.col("language_score") >= 0.85)
    .filter(F.length("text") >= 200)
    .filter(F.length("text") <= 100_000)
    .filter(F.col("quality_score") >= 0.5)
)
# Register a UDF for language detection on raw text (when no pre-computed label)
from langdetect import detect, LangDetectException

def detect_lang(text: str) -> str:
    if not text or len(text) < 20:
        return "unknown"
    try:
        return detect(text)
    except LangDetectException:
        return "unknown"

detect_lang_udf = F.udf(detect_lang, StringType())
df_with_lang = df_json.withColumn("detected_lang", detect_lang_udf(F.col("text")))
df_en = df_with_lang.filter(F.col("detected_lang") == "en")
PySpark can read all common corpus formats. Delta Lake is preferred for its ACID guarantees and time-travel capabilities (see Section T.2), but raw Parquet and JSON are also common for externally sourced datasets.
Basic DataFrame operations for text quality gating run as distributed SQL-style transformations. Language filtering, length thresholding, and quality-score cuts are all expressible as F.col predicates; simple column comparisons are additionally pushed down to the storage layer when reading Parquet, so row groups that cannot match are never read.
import hashlib

from pyspark.ml.feature import MinHashLSH, HashingTF, Tokenizer
from pyspark.sql.window import Window

# --- Exact deduplication via SHA-256 ---
def sha256_hash(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

hash_udf = F.udf(sha256_hash, StringType())
df_hashed = df_filtered.withColumn("doc_hash", hash_udf(F.col("text")))

# Keep only the first occurrence of each hash
window = Window.partitionBy("doc_hash").orderBy("url")
df_exact_deduped = (
    df_hashed
    .withColumn("rank", F.row_number().over(window))
    .filter(F.col("rank") == 1)
    .drop("rank")
)
print(f"After exact dedup: {df_exact_deduped.count():,} docs")
# --- Near-duplicate removal via MinHash LSH ---
# Assign a stable numeric ID for pair bookkeeping
df_ids = df_exact_deduped.withColumn("id", F.monotonically_increasing_id())

# Tokenize into words
tokenizer = Tokenizer(inputCol="text", outputCol="words")
df_words = tokenizer.transform(df_ids)

# Hash words into a sparse feature vector
hashing_tf = HashingTF(inputCol="words", outputCol="features", numFeatures=262144)
df_features = hashing_tf.transform(df_words)

# Build MinHash LSH model
mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
model = mh.fit(df_features)

# Find near-duplicate pairs (Jaccard similarity >= 0.8)
df_pairs = model.approxSimilarityJoin(
    df_features, df_features,
    threshold=0.2,  # Jaccard distance threshold = 1 - 0.8 similarity
    distCol="distance",
)

# Extract IDs of duplicates to remove (keep the lowest ID in each pair)
duplicates_to_remove = (
    df_pairs
    .filter(F.col("datasetA.id") != F.col("datasetB.id"))
    .select(F.greatest("datasetA.id", "datasetB.id").alias("remove_id"))
    .distinct()
)

df_near_deduped = df_features.join(
    duplicates_to_remove,
    df_features["id"] == duplicates_to_remove["remove_id"],
    how="left_anti",
)
Row-level Python UDFs incur Python-JVM serialization overhead for every row. For high-throughput pipelines, prefer pandas_udf (covered in T.1.3 and T.1.4) which processes batches of rows using Arrow, reducing overhead by 10-50x.
T.1.2 Large-Scale Text Preprocessing
Raw web corpora contain near-duplicate documents, malformed HTML, non-standard Unicode, and personally identifiable information (PII). Removing these before tokenization is essential for training data quality and legal compliance.
Deduplication
Exact deduplication uses a hash of the document content. Near-duplicate removal requires MinHash Locality Sensitive Hashing (LSH), which approximates Jaccard similarity between documents at scale.
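The intuition behind MinHash is compact enough to sketch in pure Python: across k independent hash functions, the probability that two token sets share the same minimum hash equals their Jaccard similarity, so the fraction of matching minimums estimates it. The helper names below are illustrative; Spark's MinHashLSH applies the same idea at corpus scale.

```python
import hashlib

def minhash_signature(tokens: set, k: int = 64) -> list:
    """One minimum hash per seeded hash function."""
    return [
        min(int(hashlib.md5(f"{seed}:{t}".encode()).hexdigest(), 16) for t in tokens)
        for seed in range(k)
    ]

def estimated_jaccard(a: set, b: set, k: int = 64) -> float:
    """Fraction of matching minimums approximates |a ∩ b| / |a ∪ b|."""
    sa, sb = minhash_signature(a, k), minhash_signature(b, k)
    return sum(x == y for x, y in zip(sa, sb)) / k
```

Larger k tightens the estimate (standard error shrinks as 1/sqrt(k)); Spark's numHashTables parameter plays the analogous role.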
import pandas as pd

SEQ_LEN = 4096
EOS_ID = 2  # EOS token ID (model-dependent; 2 for Llama-2-style tokenizers)

def pack_sequences(pdf: pd.DataFrame) -> pd.DataFrame:
    """Greedily pack variable-length token sequences into fixed-length chunks."""
    packed = []
    current_chunk = []
    for token_ids in pdf["token_ids"]:
        if token_ids is None or len(token_ids) == 0:
            continue
        # Append EOS after each document
        doc_tokens = list(token_ids) + [EOS_ID]
        # Split document if longer than the context window
        while len(doc_tokens) > 0:
            space = SEQ_LEN - len(current_chunk)
            current_chunk.extend(doc_tokens[:space])
            doc_tokens = doc_tokens[space:]
            if len(current_chunk) == SEQ_LEN:
                packed.append(current_chunk)
                current_chunk = []
    # Discard the final partial chunk (do not pad)
    return pd.DataFrame({"packed_sequences": packed})

# applyInPandas runs the packer once per group; unlike a scalar pandas_udf,
# it may return a different number of rows (chunks) than it received (documents)
df_packed = df_tokenized.groupBy("partition_id").applyInPandas(
    pack_sequences, schema="packed_sequences array<int>"
)
The pack_sequences function concatenates variable-length documents with EOS separators into fixed-length windows of SEQ_LEN tokens, improving GPU utilization by 15 to 30% compared to per-document truncation with padding.
Text Cleaning and PII Redaction
import re
import unicodedata

import pandas as pd
from pyspark.sql.functions import pandas_udf

# Regex patterns for PII and noise
EMAIL_RE = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+")
PHONE_RE = re.compile(r"\b(\+?1[\s.-]?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b")
HTML_TAG_RE = re.compile(r"<[^>]+>")
MULTI_SPACE_RE = re.compile(r"\s+")

@pandas_udf(StringType())
def clean_text(texts: pd.Series) -> pd.Series:
    def _clean(text: str) -> str:
        if not isinstance(text, str):
            return ""
        # Strip HTML tags
        text = HTML_TAG_RE.sub(" ", text)
        # Normalize Unicode to NFC form
        text = unicodedata.normalize("NFC", text)
        # Redact PII
        text = EMAIL_RE.sub("[EMAIL]", text)
        text = PHONE_RE.sub("[PHONE]", text)
        # Collapse whitespace
        text = MULTI_SPACE_RE.sub(" ", text).strip()
        return text

    return texts.apply(_clean)

df_cleaned = df_near_deduped.withColumn("text", clean_text(F.col("text")))
The clean_text transformation is a pandas_udf. Processing a Pandas Series per partition is significantly faster than row-level UDFs for regex-heavy operations.
from transformers import AutoTokenizer
from pyspark.sql.types import ArrayType, IntegerType
import pandas as pd

MODEL_NAME = "meta-llama/Llama-3.1-8B"
TOKENIZER_BROADCAST = None  # loaded lazily, once per Python worker process

@pandas_udf(ArrayType(IntegerType()))
def tokenize_text(texts: pd.Series) -> pd.Series:
    # Load tokenizer once per executor process (cached after first call)
    global TOKENIZER_BROADCAST
    if TOKENIZER_BROADCAST is None:
        TOKENIZER_BROADCAST = AutoTokenizer.from_pretrained(
            MODEL_NAME, use_fast=True
        )
    tok = TOKENIZER_BROADCAST
    # Tokenize the whole batch at once (fast tokenizers support batching)
    results = tok(
        texts.tolist(),
        truncation=False,
        padding=False,
        add_special_tokens=False,
    )
    return pd.Series(results["input_ids"])

df_tokenized = df_cleaned.withColumn("token_ids", tokenize_text(F.col("text")))
Tokenization runs as a pandas_udf. The tokenizer is loaded once per executor process using a global variable. Using use_fast=True enables the Rust-backed tokenizer, which is 5 to 10x faster than the Python implementation.
The C4 dataset (used to train T5 and many subsequent models) applies a sequence of quality filters that are straightforward to replicate in PySpark:
- Keep only lines ending in terminal punctuation (., !, ?, or ").
- Remove documents with fewer than 5 sentences or fewer than 3 words per line on average.
- Remove documents containing any of a list of profanity/toxic keywords.
- Remove lines containing the substring javascript (catches boilerplate web page noise).
Each filter is a simple F.col predicate or a single-pass UDF and adds negligible compute cost relative to the I/O required to read the corpus.
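The document-level version of these filters can be sketched as a pure-Python helper, which a UDF then applies per document. Names and thresholds below mirror the list above; the profanity-keyword check is omitted for brevity, and the Spark wiring is shown in comments.

```python
from typing import Optional

TERMINAL_PUNCT = (".", "!", "?", '"')

def c4_clean_document(text: str) -> Optional[str]:
    """Apply C4-style line filters; return None if the document should be dropped."""
    kept_lines = []
    for line in text.splitlines():
        line = line.strip()
        if not line.endswith(TERMINAL_PUNCT):
            continue  # keep only lines ending in terminal punctuation
        if "javascript" in line.lower():
            continue  # drop boilerplate web-page noise
        kept_lines.append(line)
    if len(kept_lines) < 5:
        return None  # fewer than 5 surviving sentences
    avg_words = sum(len(l.split()) for l in kept_lines) / len(kept_lines)
    if avg_words < 3:
        return None  # too few words per line on average
    return "\n".join(kept_lines)

# Spark wiring (sketch):
#   clean_doc_udf = F.udf(c4_clean_document, StringType())
#   df_c4 = df.withColumn("text", clean_doc_udf("text")) \
#             .filter(F.col("text").isNotNull())
```

Returning None for rejected documents lets a single isNotNull() filter drop them after the UDF runs.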
T.1.3 Tokenization and Dataset Preparation
Once text is cleaned and deduplicated, the next step is converting it into token sequences that
training frameworks can consume. HuggingFace tokenizers are the de facto standard; applying them
at scale requires wrapping them in a pandas_udf so the tokenizer is loaded once per
worker process (via a module-level cache) rather than once per row.
# Write packed sequences to Parquet with tuned row group sizes
OUTPUT_PATH = "s3://my-bucket/training/llama-sft-packed-v3"
(
    df_packed
    .write
    .mode("overwrite")
    .option("parquet.block.size", 128 * 1024 * 1024)  # 128 MB row groups
    .option("parquet.page.size", 1 * 1024 * 1024)     # 1 MB pages
    .parquet(OUTPUT_PATH)
)

# --- In your training script, load directly as a HuggingFace Dataset ---
from datasets import load_dataset

dataset = load_dataset(
    "parquet",
    data_files={"train": f"{OUTPUT_PATH}/*.parquet"},
    split="train",
    num_proc=8,
)
print(dataset)
# Dataset({features: ['packed_sequences'], num_rows: 4200000})
The packed sequences load directly as a HuggingFace Dataset via the parquet loader. The 128 MB row group setting balances metadata overhead against read parallelism.
Training efficiency requires packing multiple short documents into fixed-length sequences to avoid wasting context window capacity on padding tokens. The pack_sequences function shown earlier implements a greedy bin-packing approach per partition.
import torch
import numpy as np
import pandas as pd
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql.functions import pandas_udf, col
from sentence_transformers import SentenceTransformer

EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
EMBED_DIM = 1024
BATCH_SIZE = 256  # tune to GPU memory

# Module-level cache: populated once per Python worker process
_model_cache: dict = {}

@pandas_udf(ArrayType(FloatType()))
def embed_texts(texts: pd.Series) -> pd.Series:
    if EMBEDDING_MODEL not in _model_cache:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        _model_cache[EMBEDDING_MODEL] = SentenceTransformer(
            EMBEDDING_MODEL, device=device
        )
    model = _model_cache[EMBEDDING_MODEL]

    # Process in mini-batches to avoid OOM on large partitions
    all_embeddings = []
    text_list = texts.tolist()
    for i in range(0, len(text_list), BATCH_SIZE):
        batch = text_list[i : i + BATCH_SIZE]
        embs = model.encode(
            batch,
            batch_size=len(batch),
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        all_embeddings.extend(embs.tolist())
    return pd.Series(all_embeddings)

# Repartition to align one partition per GPU worker
df_embedded = (
    df_cleaned
    .select("id", "text", "url")
    .repartition(spark.sparkContext.defaultParallelism)
    .withColumn("embedding", embed_texts(col("text")))
)
The module-level _model_cache dict ensures the model is loaded once per worker process, regardless of how many tasks that executor runs. Repartitioning to match GPU count minimizes model-load overhead.
Converting to HuggingFace Datasets
The most efficient conversion path from Spark to a HuggingFace Dataset uses Apache
Arrow, which both Spark and HuggingFace natively support. Write Parquet from Spark, then load
with datasets.load_dataset.
Row group size directly affects streaming throughput. Groups that are too small increase metadata overhead; groups that are too large reduce parallelism when reading. The 128 MB default is a good starting point for single-GPU training. See Section T.2 for Delta Lake-specific Parquet optimization settings.
T.1.4 Embedding Generation at Scale
Generating embeddings for millions of documents is one of the most GPU-intensive offline tasks in
LLM infrastructure. The key challenge is loading the embedding model exactly once per Spark
executor (not once per row or even once per batch), and using Arrow-based pandas_udf
to send large batches to the GPU efficiently. This pattern connects directly to the vector search
infrastructure covered in Chapter 19.
Writing Embeddings and Ingesting into Vector Databases
# Write embeddings to Parquet for offline use
EMBED_PATH = "s3://my-bucket/embeddings/bge-large-v1"
(
    df_embedded
    .write
    .mode("overwrite")
    .parquet(EMBED_PATH)
)

# --- Batch upsert to Pinecone ---
from pinecone import Pinecone

PINECONE_API_KEY = "..."
INDEX_NAME = "docs-bge-large"

def upsert_to_pinecone(rows):
    """Called once per Spark partition to batch-upsert embeddings."""
    pc = Pinecone(api_key=PINECONE_API_KEY)
    index = pc.Index(INDEX_NAME)
    batch = []
    for row in rows:
        batch.append({
            "id": row["id"],
            "values": row["embedding"],
            "metadata": {"url": row["url"]},
        })
        if len(batch) == 200:
            index.upsert(vectors=batch)
            batch = []
    if batch:
        index.upsert(vectors=batch)

# foreachPartition loads the Pinecone client once per partition
df_embedded.foreachPartition(upsert_to_pinecone)
Embeddings are pushed to the vector database in batches via foreachPartition. The same partition-level pattern works for Weaviate (using weaviate.batch) and Milvus (using pymilvus.Collection.insert).
Use foreachPartition rather than foreach when writing to external systems like vector databases. foreach calls your function once per row, which causes one network round-trip per document. foreachPartition calls your function once per partition with an iterator, enabling batched writes and connection reuse that reduce total ingestion time by orders of magnitude.
T.1.5 Monitoring and Optimizing Spark Jobs
Even well-designed PySpark pipelines require tuning when run on real corpora. The Spark UI is the primary diagnostic tool; understanding its output is essential for identifying and fixing the bottlenecks most common in LLM data pipelines.
Using the Spark UI
The Spark UI is available at http://<driver-host>:4040 during a running job. On
Databricks, it is accessible via the cluster detail page. The most useful views for LLM pipeline
debugging are:
- Stages tab: shows per-stage task duration distributions. A long tail in the distribution indicates partition skew (one partition is much larger than others).
- SQL tab: shows the physical query plan with operator-level timing. Look for BroadcastNestedLoopJoin as a warning sign that a join predicate was missed.
- Executors tab: shows GC time per executor. GC time above 10% of task time usually indicates that executor heap size needs to increase or that large objects are being created per row.
- Storage tab: shows cached RDD/DataFrame sizes. If a DataFrame you intend to reuse is not cached, adding .cache() before the second use avoids recomputing it.
Common Bottlenecks in LLM Data Pipelines
# --- Diagnosing and fixing partition skew ---
# Check partition sizes after a wide transformation
from pyspark.sql.functions import spark_partition_id

partition_sizes = (
    df_cleaned
    .withColumn("pid", spark_partition_id())
    .groupBy("pid")
    .count()
    .orderBy("count", ascending=False)
)
partition_sizes.show(20)

# If max partition >> median, salt the skewed key
# Example: language="en" dominates. Salt with a random integer.
SALT_BUCKETS = 50
df_salted = df_cleaned.withColumn(
    "lang_salted",
    F.concat(F.col("language"), F.lit("_"), (F.rand() * SALT_BUCKETS).cast("int")),
)

# --- Broadcast join for small lookup tables ---
# Use when one side of a join is small (e.g., a domain blocklist)
from pyspark.sql.functions import broadcast

df_blocklist = spark.read.parquet("s3://my-bucket/blocklist/")  # small table
df_clean_domains = df_cleaned.join(
    broadcast(df_blocklist),
    on="domain",
    how="left_anti",  # exclude rows matching the blocklist
)

# --- Tune partition count for large shuffles ---
# Rule of thumb: target 100-200 MB per partition after shuffle
# For a 2 TB corpus with a 200 MB target (~10,000 partitions):
spark.conf.set("spark.sql.shuffle.partitions", str(2000 * 1024 // 200))
Python UDFs serialize each row as a Python object, cross the JVM-Python boundary, execute the Python function, and serialize the result back. On a 1 TB corpus this overhead can be the dominant cost in your pipeline. Always profile with the Spark UI before adding UDFs, and prefer built-in Spark SQL functions (F.regexp_replace, F.length, etc.) when they cover your use case. Reserve Python UDFs for logic that genuinely cannot be expressed with built-ins, and use pandas_udf for everything else.
Cost Optimization
# --- Databricks cluster autoscaling configuration (JSON, not Python) ---
# Set in cluster creation API or Databricks UI
CLUSTER_CONFIG = {
    "cluster_name": "llm-data-pipeline",
    "spark_version": "15.4.x-scala2.12",
    "node_type_id": "i3.2xlarge",
    "autoscale": {
        "min_workers": 4,
        "max_workers": 40,
    },
    "aws_attributes": {
        "availability": "SPOT_WITH_FALLBACK",
        "spot_bid_price_percent": 100,
        "first_on_demand": 2,  # keep 2 on-demand nodes as stable driver/coordinator
    },
    "spark_conf": {
        "spark.sql.adaptive.enabled": "true",
        "spark.sql.adaptive.coalescePartitions.enabled": "true",
    },
}

# --- Checkpoint long pipelines to avoid recomputation on failure ---
spark.sparkContext.setCheckpointDir("s3://my-bucket/checkpoints/")

# Checkpoint after the expensive deduplication stage; checkpoint() returns a
# new DataFrame, so reassign the result
df_near_deduped = df_near_deduped.checkpoint()  # materializes to S3, breaks lineage

# --- Cache DataFrames used more than once ---
df_tokenized.cache()
df_tokenized.count()  # trigger materialization immediately
For pipelines that process a corpus incrementally (new data arriving daily), use Delta Lake's MERGE INTO operation rather than reprocessing the entire corpus each run. Combined with Delta's change data feed, you can identify only new or updated documents and pass them through the expensive deduplication and tokenization stages. This is covered in detail in Section T.2.
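A minimal sketch of that incremental pattern using the delta-spark Python API is shown below. The table path, key column, and helper name are illustrative; the import is deferred so the function can be defined without delta-spark installed.

```python
def merge_new_docs(spark, df_new, target_path):
    """Upsert newly arrived documents into a Delta table keyed by doc_hash.

    Only matched/new rows are rewritten; the rest of the table is untouched.
    """
    from delta.tables import DeltaTable  # requires the delta-spark package

    target = DeltaTable.forPath(spark, target_path)
    (
        target.alias("t")
        .merge(df_new.alias("s"), "t.doc_hash = s.doc_hash")
        .whenMatchedUpdateAll()      # refresh documents whose content changed
        .whenNotMatchedInsertAll()   # insert genuinely new documents
        .execute()
    )
```

Because doc_hash is the SHA-256 content hash computed in T.1.2, the merge condition doubles as exact deduplication: re-crawled but unchanged documents match an existing row instead of inserting a duplicate.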
Summary
PySpark provides the distributed compute foundation for every stage of an LLM data pipeline.
The SparkSession configuration choices in T.1.1 set the foundation for throughput and
memory stability. Deduplication via exact hashing and MinHash LSH in T.1.2 is the most impactful
single quality intervention for pretraining data. The pandas_udf pattern for
tokenization and embedding generation in T.1.3 and T.1.4 bridges PySpark with the HuggingFace
and SentenceTransformer ecosystems without sacrificing throughput. Finally, the Spark UI
diagnostics and tuning patterns in T.1.5 give you the tools to identify and eliminate the
bottlenecks that emerge when pipelines move from prototype to production scale.
These techniques integrate directly with the Databricks platform (Section T.1), Delta Lake storage and ACID semantics (Section T.2), and the HuggingFace Datasets library for downstream training (Appendix K). For the vector search applications that consume the embeddings produced here, see Chapter 19.