{"id":410,"date":"2026-05-01T00:00:00","date_gmt":"2026-05-01T00:00:00","guid":{"rendered":"https:\/\/rjbarrett.redirectme.net\/?p=410"},"modified":"2026-05-01T00:00:00","modified_gmt":"2026-05-01T00:00:00","slug":"txpert-using-multiple-knowledge-graphs-for-prediction-of-transcriptomic-perturbation-effects-nature-biotechnology","status":"publish","type":"post","link":"https:\/\/rjbarrett.redirectme.net\/?p=410","title":{"rendered":"TxPert: using multiple knowledge graphs for prediction of transcriptomic perturbation effects &#8211; Nature Biotechnology"},"content":{"rendered":"<p><br \/>\n<\/p>\n<div id=\"Sec14-content\">\n<h3 class=\"c-article__sub-heading\" id=\"Sec15\">TxPert architecture<\/h3>\n<p>TxPert predicts the transcriptional response <span class=\"mathjax-tex\">\\({\\bf{y}}\\in {\\mathcal{Y}}\\subset {{\\mathbb{R}}}^{n}\\)<\/span> given a set of perturbation tokens <span class=\"mathjax-tex\">\\(P\\subset {\\mathcal{P}}:=\\{1,\\ldots ,N\\}\\)<\/span> and a basal state representation derived from a control expression profile <span class=\"mathjax-tex\">\\({\\bf{x}}\\in {\\mathcal{X}}\\subset {{\\mathbb{R}}}^{n}\\)<\/span>, which has been aligned with the predicted cell with respect to cell line and batch effect. Here, <span class=\"mathjax-tex\">\\(n\\in {\\mathbb{N}}\\)<\/span> denotes the number of experimentally measured genes and <span class=\"mathjax-tex\">\\(N\\in {\\mathbb{N}}\\)<\/span> denotes the total number of observed perturbations in the data. These perturbation tokens (or identifiers) are used to select node representations from a biological gene (product)\u2013gene (product) interaction KG whose embeddings are integrated with the basal state to produce the perturbed expression profile.<\/p>\n<p>To combine the information of the basal state and the perturbations, we first learn latent representations of both, that is, <span class=\"mathjax-tex\">\\({\\bf{x}}\\mapsto {\\bf{s}}\\in {{\\mathbb{R}}}^{d}\\)<\/span> and <span class=\"mathjax-tex\">\\(P\\mapsto \\{{{\\bf{z}}}_{p}\\in {{\\mathbb{R}}}^{d}:p\\in P\\}\\)<\/span> for a chosen latent dimension <span class=\"mathjax-tex\">\\(d\\in {\\mathbb{N}}\\)<\/span>. Then, we combine the information through latent shift, where a learned decoder <i>g<\/i><sub><i>\u03d5<\/i><\/sub> predicts the perturbation effect from the given context, that is, <span class=\"mathjax-tex\">\\(\\widehat{{\\bf{y}}}={g}_{\\phi }({\\bf{s}}+{\\sum }_{p\\in P}{{\\bf{z}}}_{p})\\)<\/span>, using the mean squared error (MSE):<\/p>\n<div id=\"Equa\" class=\"c-article-equation\">\n<p><span class=\"mathjax-tex\">$${\\mathcal{L}}({\\bf{y}},\\widehat{{\\bf{y}}})=\\frac{1}{n}| | {\\bf{y}}-\\widehat{{\\bf{y}}}| {| }_{2}^{2}$$<\/span><\/p>\n<\/div>\n<p>This setup naturally integrates with single, double and more general multiperturbation settings of order <i>m<\/i>:\u2009=\u2009<span class=\"stix\">\u2223<\/span><i>P<\/i><span class=\"stix\">\u2223<\/span> through the additive and compositional design. More sophisticated combination functions may be used to learn transition functions <span class=\"mathjax-tex\">\\({{\\bf{s}}}^{{\\prime} }={T}_{\\psi }({\\bf{s}},{\\bf{z}})\\)<\/span> that allow sequential latent cell state modeling of subsequently applied perturbations, for example, <span class=\"mathjax-tex\">\\(\\widehat{{\\bf{y}}}={g}_{\\phi }({T}_{\\psi }(\\ldots {T}_{\\psi }({\\bf{s}},{{\\bf{z}}}_{{{\\bf{p}}}_{{\\bf{1}}}})\\ldots ,{{\\bf{z}}}_{{{\\bf{p}}}_{{\\bf{m}}}}))\\)<\/span>. 
#### Basal state model

The basal state encoder is designed to capture intrinsic cellular attributes, such as cell-cycle stage, cell type and other baseline phenotypic features, by mapping a control cell's gene expression profile $\mathbf{x} \in \mathbb{R}^n$ to a compact, low-dimensional embedding $\mathbf{s} = f_{\mathtt{basal}}(\mathbf{x}) \in \mathbb{R}^d$. The basal state model was subject to hyperparameter tuning, with an MLP performing best for the unseen and double tasks (Supplementary Note 1; Supplementary Fig. 7 for cross-cell type). The MLP learns a direct, deterministic mapping from high-dimensional gene expression data to a fixed-size embedding space. It offers a simple and computationally efficient framework for representation learning while retaining the capacity to model the complex, nonlinear dependencies inherent in gene expression data.

### Basal information matching and aggregation

An important aspect of modeling the basal state is the alignment of the control with the predicted perturbed cell. Because experimental protocols can vary widely between data sources, we randomly sample this control from the same cell line and dataset or experimental protocol. Furthermore, we explore basal state matching, where the control cell is selected to closely match the batch metadata of the perturbed sample; as this matching is not unique, we randomly sample one such appropriate control. Lastly, we experiment with basal state averaging, where instead of a single control we compute the average expression profile across all matching controls for a given cell line and/or batch, producing a more stable estimate of the basal state (see the sketch below). Both strategies consistently improved model performance in our experiments.
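Basal state matching and averaging amount to selecting or pooling controls that share metadata with the perturbed cell. A minimal sketch, assuming an expression matrix with a parallel metadata table; the function names and column names (`cell_line`, `batch`) are hypothetical.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

def sample_matched_control(X: np.ndarray, meta: pd.DataFrame,
                           cell_line: str, batch: str) -> np.ndarray:
    """Basal state matching: randomly pick one control whose metadata
    matches the perturbed cell's cell line and batch."""
    idx = meta.index[(meta["cell_line"] == cell_line) & (meta["batch"] == batch)]
    return X[rng.choice(idx)]

def average_control(X: np.ndarray, meta: pd.DataFrame,
                    cell_line: str, batch: str) -> np.ndarray:
    """Basal state averaging: mean expression over all matching controls,
    giving a more stable estimate of the basal state."""
    mask = (meta["cell_line"] == cell_line) & (meta["batch"] == batch)
    return X[mask.to_numpy()].mean(axis=0)

# Usage with toy data: 100 control cells x 50 genes.
X = rng.random((100, 50))
meta = pd.DataFrame({"cell_line": ["K562"] * 100,
                     "batch": rng.choice(["b1", "b2"], size=100)})
x_single = sample_matched_control(X, meta, "K562", "b1")
x_avg = average_control(X, meta, "K562", "b1")
```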
### Encoders

Beyond MLP encoders on raw gene expression profiles, we explored multiple transcriptomics foundation model embeddings to obtain a latent representation of the basal state. Specifically, we experiment with scGPT and with scVI pretrained on the Joung dataset [39]. We also include a variant with no basal state encoder, where the (latent) basal state space is represented directly by the raw expression profile (that is, $\mathbf{s} := \mathbf{x}$). In this configuration, the perturbation decoder learns a $\Delta$ vector from the perturbation embedding, which is added to the control profile $\mathbf{x} \in \mathbb{R}^n$, that is, $\widehat{\mathbf{y}} = \mathbf{x} + g_\phi(\sum_{p \in P} \mathbf{z}_p)$. This resembles the formulation of the general baseline, with a trained model predicting the $\Delta$ instead of a hand-crafted heuristic. Unsurprisingly, this variant shines most in settings with limited data for learning robust basal state representations, for example, perturbation effect prediction in cell lines without seen perturbations.

#### Perturbation model

We rely on GNNs that use biological KGs capturing gene (product)–gene (product) interactions to learn informative embeddings for gene perturbations. The GNN learns a matrix of node embeddings associated with the perturbation tokens, $\{1, \ldots, N\} \mapsto \mathbf{Z} \in \mathbb{R}^{N \times d}$, where $N \in \mathbb{N}$ is the number of perturbations relevant to the investigated task. Each row (node representation) of this matrix is the latent encoding of a specific perturbation; that is, the $\mathbf{z}_p \in \mathbb{R}^d$ required for the latent shift is the $p$th row of $\mathbf{Z}$. More specifically, we first associate each perturbation $p \in \{1, \ldots, N\}$ with a randomly initialized input node embedding $\mathbf{h}_p^0 \in \mathbb{R}^{d_0}$, $d_0 \in \mathbb{N}$, consolidated in the input node feature matrix $\mathbf{H}^0 \in \mathbb{R}^{N \times d_0}$. During training, these node features are (1) treated as model parameters learned by backpropagation and (2) subsequently refined through the message passing of an $L$-layer GNN, that is, $\mathbf{H}^{\ell} = \mathtt{LAYER}_{\theta_\ell}(\mathbf{H}^{\ell-1})$ for $1 \le \ell \le L$, with $\mathbf{Z} := \mathbf{H}^L$. This allows the model to characterize perturbation effects on the basis of known relationships from KGs such as GO [19], STRING [18] or proprietary data sources.
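A minimal sketch of such a perturbation encoder, using PyTorch Geometric's `GATConv` with learned input node features; this is not the exact TxPert architecture, and the layer and head counts are assumptions.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv

class PerturbationGNN(nn.Module):
    """Sketch of the perturbation model: learned input node features H^0,
    refined by stacked GAT layers over a KG; row p of the output Z is z_p.
    Layer and head counts are illustrative assumptions."""

    def __init__(self, n_nodes: int, d0: int = 64, d: int = 128, heads: int = 4):
        super().__init__()
        # H^0: randomly initialized, trained by backpropagation.
        self.h0 = nn.Parameter(torch.randn(n_nodes, d0))
        self.layers = nn.ModuleList([
            GATConv(d0, d, heads=heads, concat=False),
            GATConv(d, d, heads=heads, concat=False),
        ])

    def forward(self, edge_index: torch.Tensor) -> torch.Tensor:
        h = self.h0
        for layer in self.layers:
            # Attention (re)weights neighbors, downweighting less credible edges.
            h = torch.relu(layer(h, edge_index))
        return h  # Z in R^{N x d}

# Usage: toy KG over 500 genes with random edges; z_p is row p of Z.
edge_index = torch.randint(0, 500, (2, 4000))
Z = PerturbationGNN(n_nodes=500)(edge_index)
z_p = Z[42]
```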
Real-world KGs present inherent imperfections: they often contain noisy or incorrect edges (false positives), suffer from missing connections (false negatives) and may originate from diverse sources offering multiple, sometimes conflicting, perspectives. Using such complex data effectively requires GNN architectures specifically chosen to address these challenges.

To this end, we select and adapt two fundamental GNN approaches. First, to handle noisy edges, we use attention-based models such as the GAT [40,41]. The ability of GAT to dynamically (re)weight neighbor importance provides much-needed robustness by effectively downweighting less credible connections, a major difference relative to non-attention-based methods such as the simple graph convolution [42] used in GEARS. Second, to address graph incompleteness and capture long-range dependencies, we use graph transformers (GTs), specifically Exphormer [16,17]. Its capacity for attention beyond immediate graph neighbors allows it to potentially model implicit relationships and bridge structural gaps.

Furthermore, it proved crucial for the presented tasks to use multiple KGs that offer complementary and reinforcing perspectives on the task-related biology. For this, we explore architectures designed for synergistic learning from diverse sources. This includes extending GAT to GAT-Hybrid (allowing node-level attention weighting of information from different KGs), introducing our provenance-aware Exphormer-MG variant and developing GAT-MLG, a multilayer extension of GAT that uses a supra-adjacency representation to integrate information across multiple biological networks (sketched below).

In Supplementary Note 3, we provide rigorous details on the graph representation learning used for encoding perturbations, the proposed models and the techniques applied to take advantage of complementary information from multiple KGs. In our experimental setup, the learned embeddings for the perturbed gene(s) from the GNN were extracted and combined with a basal state representation to predict the resulting gene expression profile. This comparative analysis allowed us to investigate how different GNN strategies (attention, flexible connectivity and multigraph fusion) perform when learning from the complexities of biological KGs.
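The supra-adjacency idea behind GAT-MLG can be illustrated with a small, generic multilayer-network construction: each KG becomes a block on the diagonal of a larger adjacency matrix, and identity couplings connect copies of the same gene across layers. This is a sketch under our own assumptions, not the exact GAT-MLG formulation.

```python
import numpy as np
import scipy.sparse as sp

def supra_adjacency(layers: list[sp.csr_matrix], coupling: float = 1.0) -> sp.csr_matrix:
    """Stack K gene-gene graphs over the same N nodes into one (K*N, K*N)
    supra-adjacency matrix: intra-layer blocks on the diagonal, identity
    blocks (scaled by `coupling`) linking each gene to its copies in the
    other layers. A GNN can then message-pass over this single graph."""
    n = layers[0].shape[0]
    k = len(layers)
    eye = coupling * sp.identity(n, format="csr")
    blocks = [[layers[i] if i == j else eye for j in range(k)] for i in range(k)]
    return sp.bmat(blocks, format="csr")

# Usage: two toy KGs over the same 4 genes.
g1 = sp.csr_matrix(np.array([[0, 1, 0, 0], [1, 0, 1, 0],
                             [0, 1, 0, 1], [0, 0, 1, 0]]))
g2 = sp.csr_matrix((np.ones(2), ([0, 3], [3, 0])), shape=(4, 4))
A_supra = supra_adjacency([g1, g2])  # shape (8, 8)
```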
### Training and evaluation framework

### Algorithm 1

**TxPert training algorithm**

**Require:** Perturbed cells $\mathcal{Y}$, control cells $\mathcal{X}$, biological prior graph $G = (V, E)$ with perturbations $\mathcal{P} \subset V$

**Ensure:** Minimize the MSE loss between predicted and true perturbed cell measurements

1:  Initialize input perturbation embeddings $\{\mathbf{h}_v^0\}_{v \in V}$ randomly
2:  **for** each training step **do**
3:    Sample a batch of perturbed cell profiles: $\{\mathbf{y}_i\}_{i=1}^{B} \subset \mathcal{Y}$
4:    Sample corresponding control cells from the same experimental batches: $\{\mathbf{x}_i\}_{i=1}^{B} \subset \mathcal{X}$
5:    Enrich perturbation embeddings using the graph prior: $\{\mathbf{z}_v\}_{v \in V} \leftarrow \mathtt{GNN}(G, \{\mathbf{h}_v^0\}_{v \in V})$
6:    **for** each sample in the batch **do**
7:      Encode the control cell into the basal latent space: $\mathbf{b}_i \leftarrow \mathtt{MLP}_{\mathrm{basal}}(\mathbf{x}_i)$
8:      Retrieve the perturbations $P_i \subset \mathcal{P}$ associated with target $\mathbf{y}_i$
9:      Combine control and perturbation embeddings: $\widehat{\mathbf{z}}_i \leftarrow \mathtt{COM}(\mathbf{b}_i, \{\mathbf{z}_p : p \in P_i\})$
10:     Decode to the predicted perturbed profile: $\widehat{\mathbf{y}}_i \leftarrow \mathtt{MLP}_{\mathrm{dec}}(\widehat{\mathbf{z}}_i)$
11:     Compute the loss for each sample: $\mathcal{L}_i \leftarrow \mathtt{MSE}(\widehat{\mathbf{y}}_i, \mathbf{y}_i)$
12:   **end for**
13:   Compute the total loss over the batch: $\mathcal{L} \leftarrow \sum_{i=1}^{B} \mathcal{L}_i$
14:   Backpropagate and update the model parameters
15: **end for**
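Algorithm 1 corresponds to a standard PyTorch training loop. The sketch below uses toy data and a plain embedding table in place of the GNN step (line 5); all sizes are illustrative assumptions, and `COM` is the additive latent shift.

```python
import torch
import torch.nn as nn

# Minimal stand-ins for Algorithm 1's components (illustrative assumptions):
# a basal MLP encoder, a learned embedding table in place of GNN(G, H^0),
# and a decoder; COM is the additive latent shift b + sum_p z_p.
n_genes, n_perts, d, B = 2000, 500, 128, 32
mlp_basal = nn.Sequential(nn.Linear(n_genes, d), nn.ReLU(), nn.Linear(d, d))
pert_table = nn.Embedding(n_perts, d)
mlp_dec = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, n_genes))
params = (list(mlp_basal.parameters()) + list(pert_table.parameters())
          + list(mlp_dec.parameters()))
opt = torch.optim.Adam(params, lr=1e-3)

for step in range(100):                          # 2: for each training step
    y = torch.randn(B, n_genes)                  # 3: perturbed profiles (toy data)
    x = torch.randn(B, n_genes)                  # 4: batch-matched controls (toy data)
    z_all = pert_table.weight                    # 5: enriched embeddings Z (stand-in)
    b = mlp_basal(x)                             # 7: basal latent state
    P = torch.randint(0, n_perts, (B, 1))        # 8: perturbation tokens per sample
    z_hat = b + z_all[P].sum(dim=1)              # 9: COM = additive latent shift
    y_hat = mlp_dec(z_hat)                       # 10: decode
    loss = nn.functional.mse_loss(y_hat, y)      # 11-13: aggregate MSE over batch
    opt.zero_grad()                              # 14: backpropagate and update
    loss.backward()
    opt.step()
```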
#### Data splits

The data were split into training, validation and test sets by grouping on perturbation, such that distinct sets of unseen perturbations were reserved for both the validation and test sets, with target ratios of 0.5625, 0.1875 and 0.25 for the training, validation and test sets, respectively. Moreover, for the cross-cell-type task, the test set was a reserved cell type with only control cells included during training, with a breakdown into seen and unseen perturbations therein. As an exception, for the doubles task on the Norman dataset, predefined splits were loaded from the GEARS setup.

Optimal hyperparameters for each model were selected based on the validation Pearson $\Delta$ metric. Only metrics on the test set are reported.

#### Metric definitions

All metrics are reported as weighted averages, that is, the mean over perturbations of the per-perturbation mean across cells subjected to each unique perturbation, unless otherwise specified.

### Expression value representations and Δ

Where not otherwise specified, expression values $\mathbf{x} \in \mathcal{X} \cup \mathcal{Y}$ are represented as log1p-transformed, library-size-normalized counts (with a target library size of 4,000); that is, for a raw count vector $\mathbf{x}_{\mathrm{raw}} \in \mathbb{R}^n$, we define

$$\mathbf{x} := \log\left(1 + 4000 \cdot \frac{\mathbf{x}_{\mathrm{raw}}}{\|\mathbf{x}_{\mathrm{raw}}\|_1}\right).$$

The other representation used is a $\Delta$ representation, centered on batch-matched controls. Specifically, for each perturbed cell expression $\mathbf{y}_i \in \mathcal{Y}$ with cell line $c$ and batch $b$, the expression is transformed to

$$\delta_i := \mathbf{y}_i - \overline{\mathbf{x}}_{c,b},$$

where $\overline{\mathbf{x}}_{c,b}$ is the mean expression of the control cells $\mathbf{x} \in \mathcal{X}$ with batch $b$ and cell line $c$.
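Both representations are straightforward to compute. A minimal numpy sketch, assuming a cells-by-genes count matrix; the function names are our own.

```python
import numpy as np

def lognorm(counts: np.ndarray, target: float = 4000.0) -> np.ndarray:
    """log1p-transformed, library-size-normalized counts per cell (row)."""
    lib = counts.sum(axis=1, keepdims=True)  # ||x_raw||_1 per cell
    return np.log1p(target * counts / lib)

def delta(y: np.ndarray, controls: np.ndarray) -> np.ndarray:
    """Delta representation: perturbed expression minus the mean of the
    batch-matched controls (already subset to cell line c and batch b)."""
    return y - controls.mean(axis=0)

# Usage: 5 perturbed cells and 20 matched controls over 50 genes (toy counts).
rng = np.random.default_rng(0)
y = lognorm(rng.poisson(2.0, size=(5, 50)).astype(float))
ctrl = lognorm(rng.poisson(2.0, size=(20, 50)).astype(float))
d = delta(y, ctrl)
```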
### Pearson Δ

Slightly modified from the metric 'Pearson correlation (Δ expression)' from the GEARS manuscript, Pearson $\Delta$ calculates the correlation between the predicted and observed log fold change versus the batch-matched control mean,

$$\mathrm{Pearson}\Delta(p) := \mathrm{Pearson}(\widehat{\delta}_p, \delta_p),$$

where $\widehat{\delta}_p$ and $\delta_p$ are the batch-matched control centerings of the prediction and the ground truth, respectively, averaged across replicates of a given perturbation $p \in \mathcal{P}$. For simplicity, we define this and the following metrics for single perturbations $p \in \mathcal{P}$ but note that analogous formulations apply to multiple perturbations $P \subset \mathcal{P}$. The results across all predicted perturbation effects are then averaged to obtain an overall performance estimate.

Note that, for the GEARS model only, we report the exact 'Pearson correlation (Δ expression)' from the GEARS code base instead. We confirmed that any differences between 'Pearson correlation (Δ expression)' and our 'Pearson Δ' were much smaller in practice than the differences between models.
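A sketch of the per-perturbation Pearson Δ computation on already control-centered vectors; the function name is our own.

```python
import numpy as np

def pearson_delta(delta_pred: np.ndarray, delta_true: np.ndarray) -> float:
    """Pearson correlation between the predicted and observed control-centered
    expression changes for one perturbation (vectors over genes)."""
    return float(np.corrcoef(delta_pred, delta_true)[0, 1])

# Usage: average the per-perturbation scores for the overall estimate.
rng = np.random.default_rng(0)
scores = [pearson_delta(rng.normal(size=50), rng.normal(size=50))
          for _ in range(10)]
overall = float(np.mean(scores))
```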
### Retrieval

We use two variants of the retrieval rank metric, which score a prediction's similarity to the ground truth not in absolute terms but relative to other perturbations. These metrics are the same as the rank average from PerturBench [10], except that they focus on similarity, with a perfect score of 1, a random score of 0.5 and a perfectly anticorrelated prediction scoring 0:

$$\mathrm{Retrieval} := \frac{1}{N} \sum_{p \in \mathcal{P}} \mathrm{rank}(\widehat{\delta}_p), \qquad \mathrm{rank}(\widehat{\delta}_p) := \frac{1}{N-1} \sum_{\substack{q \in \mathcal{P} \\ q \ne p}} \mathbf{1}_{\{\mathrm{Pearson}(\widehat{\delta}_p, \delta_p) \ge \mathrm{Pearson}(\widehat{\delta}_p, \delta_q)\}}.$$

For 'normalized' retrieval, the perturbation count $N := |\mathcal{P}|$ matches the original experiment, whereas, for 'fast retrieval', for computational efficiency, a seeded random reference set of only 100 perturbations is used, with the query perturbant $p$ added when it is not in the reference set (thus, $N \in \{100, 101\}$). As with Pearson $\Delta$, we report the averaged performance across all perturbations.
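A sketch of the rank computation for a single query perturbation, assuming the predicted and ground-truth Δ vectors are precomputed; the names are our own.

```python
import numpy as np

def rank_score(delta_pred: np.ndarray, deltas_true: np.ndarray, p: int) -> float:
    """Fraction of reference perturbations q != p whose ground-truth delta is
    no more similar to the prediction than perturbation p's own ground truth."""
    sims = np.array([np.corrcoef(delta_pred, d)[0, 1] for d in deltas_true])
    others = np.delete(sims, p)
    return float(np.mean(sims[p] >= others))

# Usage: retrieval is the mean rank score over all perturbations.
rng = np.random.default_rng(0)
deltas_true = rng.normal(size=(20, 50))               # 20 perturbations x 50 genes
deltas_pred = deltas_true + 0.5 * rng.normal(size=(20, 50))
retrieval = float(np.mean([rank_score(deltas_pred[p], deltas_true, p)
                           for p in range(20)]))      # ~1 good, 0.5 random
```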
#### Nonlearned general baseline

To establish a performance floor, we implement a nonlearned general baseline model that predicts expression profiles using mean values observed in the training data. This baseline uses an additive approach that combines the following (see the sketch after the formal definition below):

- The mean test cell type control expression profile
- Either the perturbation-specific mean changes (for seen perturbations) or the global perturbation mean (for unseen perturbations)
- When multiple cell lines are present in the training set, either a weighted average according to the number of samples per cell line or the perturbation-specific mean changes from the most similar cell line (nearest-cell-line baseline); here, similarity is determined by the mean correlation of shared perturbation $\Delta$ values between the test and candidate cell lines

Consider a multiset of training samples $\mathrm{Train} \subset \mathcal{P}_{\mathrm{train}} \times \mathcal{C}_{\mathrm{train}} \times \mathcal{B}_{\mathrm{train}}$ consisting of combinations of perturbation(s), cell line(s) and batch effect(s), with a multiset $\mathrm{Test}$ defined analogously. Consider a perturbation $p$ such that $(p, c_p, b_p) \in \mathrm{Test}$, with cell line $c_p$ and batch effect $b_p$. Implicitly, each pair $(c_p, b_p)$ is associated with a set of control cell profiles in that context.

If there exists $(p, c, b) \in \mathrm{Train}$, we have

$$\widehat{\mathbf{y}}_{(p, c_p, b_p)} = \overline{\mathbf{x}}_{(c_p, b_p)} + \frac{1}{|\{(q, c, b) \in \mathrm{Train} : q = p\}|} \sum_{\substack{(q, c, b) \in \mathrm{Train} \\ q = p}} \left( \mathbf{y}_{(q, c, b)} - \overline{\mathbf{x}}_{(c, b)} \right).$$

Otherwise, we use the global $\Delta$ across perturbations observed in the training set, that is,

$$\widehat{\mathbf{y}}_{(p, c_p, b_p)} = \overline{\mathbf{x}}_{(c_p, b_p)} + \frac{1}{|\mathrm{Train}|} \sum_{(q, c, b) \in \mathrm{Train}} \left( \mathbf{y}_{(q, c, b)} - \overline{\mathbf{x}}_{(c, b)} \right).$$

For multiple perturbations, the baseline first attempts to use samples where the exact perturbation configuration is present. Otherwise, the perturbation is split into its components and each component is sequentially added to the test control mean according to the above method, adding a local $\Delta$ estimate when available and resorting to the global $\Delta$ otherwise.
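A minimal sketch of the seen/unseen branching in this baseline for a single cell line and batch context; the multiset bookkeeping and nearest-cell-line logic are omitted, and the data layout is an assumption.

```python
import numpy as np

def baseline_predict(p: str, ctrl_mean: np.ndarray,
                     train_deltas: dict[str, list[np.ndarray]]) -> np.ndarray:
    """Nonlearned baseline sketch: add the perturbation-specific mean delta
    when p was seen in training, else the global mean delta over all
    training samples. `train_deltas` maps perturbation -> per-sample deltas
    (y minus the batch-matched control mean)."""
    if p in train_deltas:                                  # local delta
        d = np.mean(train_deltas[p], axis=0)
    else:                                                  # global delta
        all_deltas = [d for ds in train_deltas.values() for d in ds]
        d = np.mean(all_deltas, axis=0)
    return ctrl_mean + d

# Usage: toy training deltas over 50 genes.
rng = np.random.default_rng(0)
train_deltas = {g: [rng.normal(size=50) for _ in range(3)] for g in ["TP53", "MYC"]}
ctrl_mean = rng.random(50)
y_seen = baseline_predict("TP53", ctrl_mean, train_deltas)
y_unseen = baseline_predict("GATA1", ctrl_mean, train_deltas)
```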
#### Experimental reproducibility estimation: split-half validation and sample-based extension

As Perturb-seq is a destructive assay, we cannot observe the same cell in both the perturbed and unperturbed states. This necessitates focusing on distribution means rather than individual cell accuracies. To approximate experimental reproducibility, we first use a split-half validation approach. For each combination of perturbation(s), cell line context and batch, we apply three operations:

1. Divide the test cells into two roughly equal halves
2. Calculate the mean expression profile of each half
3. Measure the agreement between these means using various metrics

To account for the randomness of the half-split, we repeat the experiment across multiple seeded runs and report the average performance (see the sketch below). This provides a performance benchmark analogous to the human-level accuracy benchmarks used in other machine learning domains.

Consider the set of expression profiles $\mathcal{S} \subset \mathcal{Y}$ for a fixed perturbation, cell line context and batch $(p, c, b)$ in the test set:

$$\mathcal{S}' \subseteq \mathcal{S}, \quad |\mathcal{S}'| \approx |\mathcal{S}|/2, \qquad \overline{\mathcal{S}}_1 = \frac{1}{|\mathcal{S}'|} \sum_{\mathbf{y} \in \mathcal{S}'} \mathbf{y}, \qquad \overline{\mathcal{S}}_2 = \frac{1}{|\mathcal{S} \setminus \mathcal{S}'|} \sum_{\mathbf{y} \in \mathcal{S} \setminus \mathcal{S}'} \mathbf{y}.$$

We then report

$$\mathrm{Reproduce}(p, c, b) = \mathrm{Metric}(\overline{\mathcal{S}}_1, \overline{\mathcal{S}}_2),$$

where $\mathrm{Metric}$ is any of our evaluation metrics, for example, Pearson $\Delta$, retrieval or MSE. In theory, split-half experimental reproducibility is not guaranteed to upper-bound the performance of all models at test time because it operates on a different test set (using only half of the cells for prediction and testing, respectively). Empirically, however, it proves to be a competitive (but still theoretically reachable) mark to beat.

As split-half reproducibility is likely an underestimate because of the reduced (halved) number of replicates, we also introduce a sample-based approach that estimates the reproducibility of the full-size dataset. Using the original count matrix, we calculate the per-batch probability distribution over genes (the maximum-likelihood estimator of a multinomial distribution) for each perturbation. We then sample from these distributions to generate two datasets with the same number of observations (that is, cells) as the original dataset but with stochastically resampled counts. These are log1p-normalized and subset to the HVGs in the same way as the original dataset before calculating the experimental reproducibility as described above. A comparison of split-half and sampled reproducibility is shown in Supplementary Table 4.
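A sketch of split-half reproducibility for one group of replicate cells, using Pearson correlation as the example metric; the seeding scheme and data are illustrative.

```python
import numpy as np

def split_half_reproducibility(S: np.ndarray, n_seeds: int = 10) -> float:
    """Split-half reproducibility for one (perturbation, cell line, batch)
    group: average, over seeded random splits, the Pearson correlation
    between the mean expression profiles of the two halves."""
    scores = []
    for seed in range(n_seeds):
        idx = np.random.default_rng(seed).permutation(len(S))
        half = len(S) // 2
        m1 = S[idx[:half]].mean(axis=0)
        m2 = S[idx[half:]].mean(axis=0)
        scores.append(np.corrcoef(m1, m2)[0, 1])
    return float(np.mean(scores))

# Usage: 40 replicate cells x 50 genes sharing a common signal.
rng = np.random.default_rng(0)
signal = rng.normal(size=50)
S = signal + 0.5 * rng.normal(size=(40, 50))
rep = split_half_reproducibility(S)
```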
### Data

#### Perturb-seq data sources

We demonstrate the efficacy of our approach across a range of datasets, including CRISPRi (gene knockdown) screens of ~2,000 essential genes in the K562 and RPE1 cell lines from a previous study [13] (also used in GEARS [4]) and similarly designed CRISPRi experiments in the Jurkat and HEPG2 cell lines from another previous study [24]. Furthermore, we use the Norman [15] dataset, comprising 94 unique single and 110 unique double CRISPRa (gene overexpression) perturbations in the K562 cell line.

#### Graphs: sourcing and processing

The graphs used as inductive bias in this work fall into two main categories: (1) curated publicly available biological knowledge and (2) large-scale perturbation screens.

The curated graphs in category 1 include the GO graph, first used by GEARS, which is constructed by assigning edges between nodes with a high Jaccard index of their GO terms [19], the STRING graph [18] and Reactome [43].

Category 2 graphs are generated from large-scale perturbation screens, including DepMap [44] and Perturb-seq [45]. These are extensive datasets linking genetic perturbations to morphological or transcriptomic outcomes, which can offer particularly crucial insights into cellular responses to stimuli. To translate these experimental screens into graphs, we use derived embeddings that represent the genes and cell lines in a high-dimensional space, allowing for the analysis of relationships and the identification of dependencies.

To curate these graphs, we first compute a pairwise similarity score between all combinations of genes. That is, for each pair of genes $(g_i, g_j)$, we compute the cosine similarity between their (aggregated) embeddings $\mathbf{x}_{g_i}$ and $\mathbf{x}_{g_j}$:

$$\mathrm{cosine\ similarity}(\mathbf{x}_{g_i}, \mathbf{x}_{g_j}) = \frac{\mathbf{x}_{g_i} \cdot \mathbf{x}_{g_j}}{\|\mathbf{x}_{g_i}\| \, \|\mathbf{x}_{g_j}\|} = \frac{\sum_{k=1}^n \mathbf{x}_{g_i k} \, \mathbf{x}_{g_j k}}{\sqrt{\sum_{k=1}^n \mathbf{x}_{g_i k}^2} \cdot \sqrt{\sum_{k=1}^n \mathbf{x}_{g_j k}^2}},$$

where

- $\mathbf{x}_{g_i} \cdot \mathbf{x}_{g_j}$ is the dot product of the two vectors
- $\|\mathbf{x}_{g_i}\|$ and $\|\mathbf{x}_{g_j}\|$ are the Euclidean norms (magnitudes) of the vectors $\mathbf{x}_{g_i}$ and $\mathbf{x}_{g_j}$, respectively
- $\mathbf{x}_{g_i k}$ and $\mathbf{x}_{g_j k}$ are the individual components of the vectors $\mathbf{x}_{g_i}$ and $\mathbf{x}_{g_j}$.

These cosine similarities are converted to their absolute values because strongly negative and strongly positive similarities both indicate a strong relationship, and the sign does not translate directly into a signed edge weight in the graph.

We additionally use proprietary data from internal genome-wide perturbation screens, where we measure the similarity of perturbation effects using both microscopy imaging and transcriptomics in various cell types.

Filtering configurations were optimized empirically. We found that the most performant configuration involved selecting the top 1% of edges by (absolute) weight for screen-based graphs. For all other graph types, we additionally filtered to no more than 20 incoming edges per target node. Edge direction was assigned arbitrarily for undirected edges (see the sketch below).
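A sketch of these screen-to-graph curation steps (absolute cosine similarity, top-1% edge selection, in-degree cap); the thresholding details are simplified assumptions.

```python
import numpy as np

def screen_graph(emb: np.ndarray, top_frac: float = 0.01, max_in: int = 20) -> np.ndarray:
    """Build a weighted adjacency matrix from gene embeddings: absolute
    cosine similarity between all gene pairs, keep the top fraction of
    edges by weight, then cap the number of incoming edges per target."""
    norm = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sim = np.abs(norm @ norm.T)                  # |cosine similarity|
    np.fill_diagonal(sim, 0.0)
    # Keep only the top `top_frac` of edge weights.
    thresh = np.quantile(sim, 1.0 - top_frac)
    adj = np.where(sim >= thresh, sim, 0.0)
    # Cap the in-degree per target node (column).
    for j in range(adj.shape[1]):
        col = adj[:, j]
        if np.count_nonzero(col) > max_in:
            keep = np.argsort(col)[-max_in:]     # strongest max_in sources
            mask = np.zeros_like(col, dtype=bool)
            mask[keep] = True
            adj[~mask, j] = 0.0
    return adj

# Usage: 100 genes with 32-dimensional screen-derived embeddings.
emb = np.random.default_rng(0).normal(size=(100, 32))
A = screen_graph(emb)
```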
#### Data understanding

Additional methods related to specific analyses are described below.

### Pharos knowledge rank

The Pharos initiative consolidates a variety of statistics on how researched and well known specific genes are [26]. Starting from these, we ranked knowledge levels as the mean of the rank of the Pharos PubMed score and the rank of the Pharos negative log novelty score, creating a single Pharos knowledge rank. We used this rank to break down and compare model performance and to understand potential bias. The knowledge levels 0–4 correspond to the following bins of the Pharos knowledge rank (see the sketch below):

- knowledge level 0 (least characterized): 0–0.2
- knowledge level 1: 0.2–0.4
- knowledge level 2: 0.4–0.6
- knowledge level 3: 0.6–0.8
- knowledge level 4 (most characterized): 0.8–1
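A sketch of how such a rank could be assembled with pandas; the Pharos column names are hypothetical and the binning mirrors the levels above.

```python
import pandas as pd

def knowledge_rank(scores: pd.DataFrame) -> pd.Series:
    """Sketch of the Pharos knowledge rank: mean of the percentile rank of
    the PubMed score and the percentile rank of the negative log novelty
    score, then binned into knowledge levels 0-4. The column names
    (`pubmed_score`, `log_novelty`) are our own assumptions."""
    r_pubmed = scores["pubmed_score"].rank(pct=True)
    r_novelty = (-scores["log_novelty"]).rank(pct=True)
    rank = (r_pubmed + r_novelty) / 2
    return pd.cut(rank, bins=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
                  labels=[0, 1, 2, 3, 4], include_lowest=True)

# Usage with toy per-gene scores.
scores = pd.DataFrame({"pubmed_score": [5.0, 120.0, 2500.0],
                       "log_novelty": [-1.0, -3.0, -6.0]},
                      index=["GENE_A", "GENE_B", "GENE_C"])
levels = knowledge_rank(scores)
```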
### Within versus across

In investigating the correlations between controls and mean baselines, we compared 'within'-context correlation to 'across'-context correlation. Generally, before calculating either, all examples were first split into two mutually exclusive halves, A and B, where within-context correlation compares A versus B in each context, while across-context correlation compares an arbitrary half of one context with another context. The only exception is the across-batch controls in Fig. 1a, for which, to make a conservative estimate of across-batch variance, full batches were aggregated without splitting. For the batch comparison, individual control cells were split and aggregated; for the mean baselines, the $\delta$ of perturbant replicate cells was preaggregated and then split (such that the halves had nonoverlapping perturbations).

### Functional enrichment

To obtain a descriptive biological summary of the actual gene expression changes in the mean baseline, we first calculated a meta mean baseline (the mean over all cell types and datasets, using the intersection of provided expressed genes) and defined upregulated and downregulated genes as those with a remaining $\delta > 0.05$ or $\delta < -0.05$, respectively. We then ran functional enrichment testing separately on each of these sets (versus the background of all genes in the dataset intersection) using the GOATOOLS package [46].

### Metric selection through retrieval

To avoid confusion and distraction caused by reporting many similarly performing or slightly contradictory metrics, we first performed an empirical 'test of the test metrics'. We adapted the evaluation method pioneered by a previous study [12] to work for our data and task. In short, this method uses cross-context retrieval of a perturbation to judge whether a representation and metric together retain the details necessary to distinguish perturbations. In our case, we modified the method to work on nonaggregated single-cell expression profiles (as this is the input and output of our model) and ran retrieval across the essential perturbation set on the core cell types (K562, RPE1, HEPG2 and Jurkat). For each retrieval calculation, instead of aggregating, we first randomly sampled one cell per perturbation. We ran three replicates on each of the cell × cell pairings for $n = 3 \times 6 = 18$ total estimates per perturbation. We report the 0.9 quantile to focus on active perturbants; similar patterns were observed at other thresholds.

Note that the choice to focus on single cells excluded the representation selected by a previous study [12], namely the signed $P$ value. To derisk this, we ran a preliminary analysis on the exact setup described previously [12]. We were only able to reproduce their results when making choices that would have limited the extensibility of our data and training setup; in particular, we found that the high performance of the signed $P$ value relied on performing a global fit (and, thus, using a global estimate of gene-wise variance) across all contexts when determining differential expression.

We focused our selection on representations and metrics especially common in the perturbational transcriptomics literature but acknowledge the current omission of count-based representations and metrics.

### General

Model development and analysis were performed with Python 3.12 and PyTorch 2.6.0. Plotting was performed with a combination of seaborn 0.13.2 and Matplotlib 3.10.8. Box-and-whisker plots used seaborn defaults, where the box represents the 0.25–0.75 quantiles and the center line the median. The whiskers extend to the furthest observed data point within 1.5× the interquartile range of the nearest quartile.

### Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.