Training method, apparatus, device, medium and product for a pre-trained model
By freezing the basic parameters of the pre-trained encoder and employing a momentum contrastive learning framework and a memory queue, low-cost alignment of pathological images and gene expression data was achieved. This solves the problems of high training cost and limited alignment quality in existing technologies, and improves the stability and effectiveness of cross-modal alignment.
Patent Information
- Authority / Receiving Office
- CN · China
- Patent Type
- Application (China)
- Current Assignee / Owner
- FUDAN UNIVERSITY
- Filing Date
- 2026-04-03
- Publication Date
- 2026-06-19
AI Technical Summary
Existing technologies for aligning pathological images with gene expression data suffer from high training costs, loss of general representation capabilities on single-modal data, and limited cross-modal alignment quality.
A training method for pre-trained models is adopted in which the basic parameters are frozen and only the injected trainable parameter structures are updated. Combined with a momentum contrastive learning framework and a memory queue, this achieves effective alignment of pathological images and gene expression data.
It reduces training costs, avoids catastrophic forgetting, improves the quality and stability of cross-modal alignment, and maps pathological images and gene expression data into a unified semantic space.
Figure CN122241234A_ABST
Description
Technical Field
[0001] This application relates to the field of computer technology, and more specifically, to a training method and apparatus for a pre-trained model, an electronic device, a computer-readable storage medium, and a computer program product.
Background Technology
[0002] In the fields of artificial intelligence and deep learning, multimodal learning aims to process and correlate information from different data sources. For example, in spatial transcriptomics, simultaneously acquiring pathological images of tissue sections and corresponding gene expression data is of great significance for understanding the tissue microenvironment. However, there is a significant "modal gap" between pathological images (visual modality) and gene expression data (genomic modality), meaning that the two differ greatly in data form, feature space, and semantic granularity.
[0003] In the related art, contrastive learning methods are commonly used to align the two modalities. However, this approach has the following problems. First, fine-tuning all parameters of a large pre-trained single-modality model (such as a pathological image encoder or a gene expression encoder) requires large amounts of GPU (Graphics Processing Unit) memory and long training times, resulting in excessively high training costs. Second, full-parameter fine-tuning can easily cause the model to lose the general representational abilities learned on single-modality data, leading to the "catastrophic forgetting" problem. Finally, standard contrastive learning is limited by GPU memory, which makes it difficult to scale up batch sizes; the resulting shortage of negative samples weakens the contrast signal and limits the quality of cross-modal alignment.
[0004] Therefore, how to achieve effective alignment of pathological images and gene expression data with low training cost while maintaining the capabilities of the single-modal basic models is a technical problem that needs to be solved by those skilled in the art.
Summary of the Invention
[0005] The purpose of this application is to provide a training method, apparatus, electronic device, computer-readable storage medium, and computer program product for a pre-trained model, which achieves effective alignment of pathological images and gene expression data with low training cost while maintaining the capabilities of a single-modal basic model.
[0006] To achieve the above objectives, this application provides a training method for a pre-trained model, comprising:
obtaining a training dataset, wherein the training dataset includes multiple training samples, and each training sample includes a pathological image and corresponding gene expression data;
constructing a pre-trained model, wherein the pre-trained model includes a query branch and a key branch; the query branch includes a first query sub-branch and a second query sub-branch, the first query sub-branch including a pathological image encoder, an image projection head and an image prediction head connected in sequence, and the second query sub-branch including a gene expression encoder, a gene projection head and a gene prediction head connected in sequence; the key branch includes a first key sub-branch and a second key sub-branch, the first key sub-branch including a pathological image encoder and an image projection head connected in sequence, and the second key sub-branch including a gene expression encoder and a gene projection head connected in sequence; and the pathological image encoder and the gene expression encoder include trainable parameter structures;
inputting the training samples into the pre-trained model and generating query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch;
calculating an alignment loss based on the similarity between the query vectors and a key vector set, wherein the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in a memory queue; the alignment loss is calculated from a first alignment loss and a second alignment loss, the first alignment loss being calculated from the similarity between the pathological image query vector and the gene expression key vectors, and the second alignment loss from the similarity between the gene expression query vector and the pathological image key vectors;
based on the alignment loss, updating the trainable parameters in the query branch through backpropagation, updating the trainable parameters in the key branch through an exponential moving average, and writing the key vectors corresponding to the current training samples into the memory queue; and
inputting a pathological image to be inferred into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, wherein the pathological image representation is applied to a downstream task, the downstream task being any one of cross-modal retrieval, tissue classification, and gene expression prediction.
[0007] Optionally, after the training dataset is obtained, the method further includes: extracting data at multiple spatial scales from the training samples, wherein the multiple spatial scales include a single-point scale, a first neighborhood scale, and a second neighborhood scale larger than the first neighborhood scale; encoding the pathological images at each spatial scale with the pathological image encoder to obtain multi-scale pathological image embeddings; and encoding the gene expression data at each spatial scale with the gene expression encoder to obtain multi-scale gene expression embeddings. The first query sub-branch further includes a first multi-scale fusion module, connected between the pathological image encoder and the image projection head, which fuses the multi-scale pathological image embeddings into a fused pathological image embedding; the second query sub-branch further includes a second multi-scale fusion module, connected between the gene expression encoder and the gene projection head, which fuses the multi-scale gene expression embeddings into a fused gene expression embedding. Likewise, the first key sub-branch further includes a third multi-scale fusion module, connected between the pathological image encoder and the image projection head, which fuses the multi-scale pathological image embeddings into a fused pathological image embedding; and the second key sub-branch further includes a fourth multi-scale fusion module, connected between the gene expression encoder and the gene projection head, which fuses the multi-scale gene expression embeddings into a fused gene expression embedding.
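The multi-scale construction and fusion can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the embodiment's implementation: neighborhoods are assumed to be defined by Euclidean radii `r1 < r2` on spot coordinates, the neighborhood-scale gene expression is assumed to be the mean over the spots inside each radius, and `fuse_scales` (a softmax-weighted sum with learnable logits) is a hypothetical stand-in for the multi-scale fusion modules, whose internal structure is not specified in this paragraph.

```python
import math


def neighbors_within(coords, center_idx, radius):
    """Indices of spots whose Euclidean distance to the center spot is <= radius (center included)."""
    cx, cy = coords[center_idx]
    return [i for i, (x, y) in enumerate(coords) if math.hypot(x - cx, y - cy) <= radius]


def mean_pool(vectors):
    """Element-wise mean of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def multi_scale_expression(coords, expr, center_idx, r1, r2):
    """Single-point, first-neighborhood and second-neighborhood views of one spot (hypothetical pooling)."""
    if not r2 > r1:
        raise ValueError("the second neighborhood scale must be larger than the first")
    return [
        expr[center_idx],
        mean_pool([expr[i] for i in neighbors_within(coords, center_idx, r1)]),
        mean_pool([expr[i] for i in neighbors_within(coords, center_idx, r2)]),
    ]


def fuse_scales(embeddings, logits):
    """Hypothetical fusion module: softmax(logits)-weighted sum of the per-scale embeddings."""
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(embeddings[0])
    return [sum(w * emb[d] for w, emb in zip(weights, embeddings)) for d in range(dim)]
```

In a real model each scale would first pass through the (frozen, LoRA-adapted) encoder and the fusion logits would be trained; here the per-scale vectors are fused directly only to show the data flow.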
[0008] Optionally, before the training samples are input into the pre-trained model, the method further includes: augmenting the pathological image with a first augmentation strategy to obtain a query pathological image, and augmenting the pathological image with a second augmentation strategy to obtain a key pathological image, wherein the augmentation intensity of the first augmentation strategy is higher than that of the second augmentation strategy. Accordingly, inputting the training samples into the pre-trained model and generating the query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, includes: inputting the query pathological image into the first query sub-branch to generate the pathological image query vector, and inputting the gene expression data into the second query sub-branch to generate the gene expression query vector; and inputting the key pathological image into the first key sub-branch to generate the pathological image key vector, and inputting the gene expression data into the second key sub-branch to generate the gene expression key vector.
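The dual-view idea (a stronger augmentation for the query view, a weaker one for the key view of the same patch) can be sketched as below. Only the relative strength comes from the text; the concrete operations (random horizontal flip plus a global brightness shift) and the `MAX_DELTA` values are hypothetical choices for illustration, and images are stored as C x H x W nested lists with values in [0, 1].

```python
import random


def hflip(img):
    """Horizontally flip a C x H x W image stored as nested lists."""
    return [[row[::-1] for row in channel] for channel in img]


def brightness_jitter(img, max_delta, rng):
    """Shift every pixel by one random offset in [-max_delta, max_delta], then clamp to [0, 1]."""
    delta = rng.uniform(-max_delta, max_delta)
    return [[[min(1.0, max(0.0, p + delta)) for p in row] for row in channel] for channel in img]


# Hypothetical policies: the query view is perturbed more strongly than the key view.
MAX_DELTA = {"strong": 0.4, "weak": 0.1}


def augment(img, strength, rng):
    out = hflip(img) if rng.random() < 0.5 else img
    return brightness_jitter(out, MAX_DELTA[strength], rng)


def make_views(img, seed=0):
    """Return (query_image, key_image): strongly and weakly augmented views of the same patch."""
    rng = random.Random(seed)
    return augment(img, "strong", rng), augment(img, "weak", rng)
```

Keeping the key view close to the original image is a common way to give the slowly updated key branch stable targets, while the heavier query augmentation makes the matching task harder for the branch that receives gradients.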
[0009] Optionally, the gene expression data includes gene identifier sequences and gene expression values. After the training samples are input into the pre-trained model, the method further includes: expanding the pathological image embedding output by the pathological image encoder along the sequence dimension and concatenating it with the gene embeddings output by the gene name embedding layer of the gene expression encoder to obtain fused features; predicting gene expression values from the fused features; calculating a reconstruction loss from the predicted gene expression values and the gene expression values in the training samples; and calculating a total training loss from the alignment loss and the reconstruction loss. Accordingly, updating the parameters of the low-rank adaptation module in the query branch through backpropagation based on the alignment loss includes: updating the parameters of the low-rank adaptation module in the query branch through backpropagation based on the total training loss.
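The conditional reconstruction step can be sketched as follows: the image embedding is repeated at every sequence position, concatenated with each gene embedding, and a head predicts one expression value per position. The linear head, the mean-squared-error form of the reconstruction loss (computed only over non-padded positions), and the weighting factor `lam` are all assumptions for illustration; the text only states that a reconstruction loss is combined with the alignment loss.

```python
def expand_and_concat(img_emb, gene_embs):
    """Repeat the image embedding along the sequence and concatenate it with each gene embedding."""
    return [list(img_emb) + list(g) for g in gene_embs]


def predict_values(fused, weights, bias):
    """Hypothetical linear reconstruction head: one scalar expression value per position."""
    return [sum(w * x for w, x in zip(weights, f)) + bias for f in fused]


def masked_mse(pred, target, mask):
    """Mean squared error over valid positions (mask == 1); padded positions are ignored."""
    pairs = [(p, t) for p, t, m in zip(pred, target, mask) if m]
    if not pairs:
        raise ValueError("no valid positions to reconstruct")
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


def total_loss(l_align, l_rec, lam=1.0):
    """Total training loss = alignment loss + lam * reconstruction loss (lam is hypothetical)."""
    return l_align + lam * l_rec


# Usage: a 2-d image embedding, three 1-d gene embeddings, the last position is padding (-2.0).
fused = expand_and_concat([1.0, 0.0], [[0.5], [0.25], [0.0]])
pred = predict_values(fused, weights=[1.0, 0.0, 2.0], bias=0.0)
l_rec = masked_mse(pred, target=[2.0, 1.0, -2.0], mask=[1, 1, 0])
```

Masking matters here because padded expression values (-2.0) are placeholders, not measurements; including them would train the head to reproduce padding.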
[0010] Optionally, updating the trainable parameters in the query branch through backpropagation based on the alignment loss, updating the trainable parameters in the key branch through an exponential moving average, and writing the key vectors corresponding to the current training samples into the memory queue includes: in each forward pass, storing the key vectors produced by that forward pass in a buffer; and after every preset number of forward passes, performing a parameter update. During the parameter update, the trainable parameters in the query branch are updated through backpropagation using the gradients accumulated over the preset number of forward passes; the trainable parameters in the key branch are then updated by exponential moving average based on the updated trainable parameters of the query branch, wherein the update amount of each trainable parameter in the key branch is the weighted difference between the current value of the corresponding trainable parameter in the query branch and that of the trainable parameter in the key branch; finally, the key vectors from the preset number of forward passes stored in the buffer are written into the memory queue, and the buffer is cleared.
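The interplay of gradient accumulation, the exponential moving average, and the buffered queue write can be illustrated with a toy model that has one scalar query parameter and its key copy. The scalar parameters, averaging the accumulated gradient, the momentum value, and the queue capacity are hypothetical stand-ins for the real LoRA and head parameters; the EMA step matches the text's description, since the key parameter moves by a weighted difference `(1 - m) * (q - k)`.

```python
from collections import deque


class MomentumUpdater:
    """Toy model: one scalar query-branch parameter and its EMA key-branch copy (hypothetical)."""

    def __init__(self, accum_steps=4, momentum=0.99, lr=0.1, queue_size=65536):
        self.q_param = 1.0                      # trainable query-branch parameter
        self.k_param = 1.0                      # key-branch copy, never updated by gradients
        self.accum_steps = accum_steps          # the "preset number" of forward passes
        self.momentum = momentum
        self.lr = lr
        self.grad_sum = 0.0                     # gradients accumulated since the last update
        self.buffer = []                        # key vectors produced since the last update
        self.queue = deque(maxlen=queue_size)   # FIFO memory queue: oldest keys are dropped first
        self.passes = 0

    def step(self, key_vector, grad):
        """One forward/backward pass: stash the key vector and accumulate the gradient."""
        self.buffer.append(key_vector)
        self.grad_sum += grad
        self.passes += 1
        if self.passes % self.accum_steps == 0:
            self._update()

    def _update(self):
        # 1) Gradient step on the query parameter with the averaged accumulated gradient.
        self.q_param -= self.lr * self.grad_sum / self.accum_steps
        self.grad_sum = 0.0
        # 2) EMA on the key parameter: k <- k + (1 - m) * (q - k).
        self.k_param += (1.0 - self.momentum) * (self.q_param - self.k_param)
        # 3) Only now move the buffered keys into the queue, then clear the buffer.
        self.queue.extend(self.buffer)
        self.buffer.clear()
```

One plausible reason for deferring the queue write is that the queue then changes in step with the parameter updates rather than on every micro-batch; the text itself only specifies the order of operations.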
[0011] Optionally, the trainable parameter structure is a low-rank adaptation module. The target layers of the pathological image encoder into which the low-rank adaptation module is injected include the joint query-key-value projection layer and the output projection layer of the attention layers of the pathological image encoder's Vision Transformer. The target layers of the gene expression encoder into which the low-rank adaptation module is injected include the query, key, value, and output projection layers of the gene expression encoder's Transformer, and the linear layers of its feedforward network.
[0012] To achieve the above objectives, this application provides a training apparatus for a pre-trained model, comprising:
an acquisition unit, configured to acquire a training dataset, wherein the training dataset includes multiple training samples, and each training sample includes a pathological image and corresponding gene expression data;
a construction unit, configured to construct a pre-trained model, wherein the pre-trained model includes a query branch and a key branch; the query branch includes a first query sub-branch and a second query sub-branch, the first query sub-branch including a pathological image encoder, an image projection head and an image prediction head connected in sequence, and the second query sub-branch including a gene expression encoder, a gene projection head and a gene prediction head connected in sequence; the key branch includes a first key sub-branch and a second key sub-branch, the first key sub-branch including a pathological image encoder and an image projection head connected in sequence, and the second key sub-branch including a gene expression encoder and a gene projection head connected in sequence; and the pathological image encoder and the gene expression encoder include trainable parameter structures;
an input unit, configured to input the training samples into the pre-trained model and generate query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch;
a first calculation unit, configured to calculate an alignment loss based on the similarity between the query vectors and a key vector set, wherein the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in a memory queue; the alignment loss is calculated from a first alignment loss and a second alignment loss, the first alignment loss being calculated from the similarity between the pathological image query vector and the gene expression key vectors, and the second alignment loss from the similarity between the gene expression query vector and the pathological image key vectors;
an update unit, configured to update the trainable parameters in the query branch through backpropagation based on the alignment loss, update the trainable parameters in the key branch through an exponential moving average, and write the key vectors corresponding to the current training samples into the memory queue; and
an inference unit, configured to input a pathological image to be inferred into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, wherein the pathological image representation is applied to a downstream task, the downstream task being any one of cross-modal retrieval, tissue classification, and gene expression prediction.
[0013] To achieve the above objectives, this application provides an electronic device, comprising: a memory, configured to store a computer program; and a processor, configured to execute the computer program to implement the steps of the training method for the pre-trained model described above.
[0014] To achieve the above objectives, this application provides a computer-readable storage medium storing a computer program that, when executed by a processor, implements the steps of the training method for the pre-trained model described above.
[0015] To achieve the above objectives, this application provides a computer program product, including a computer program that, when executed by a processor, implements the steps of the training method for the pre-trained model described above.
[0016] The training method for the pre-trained model provided in this application freezes the basic parameters of the pre-trained encoders and updates only the injected trainable parameter structures, which greatly reduces the number of parameters to be trained. This parameter-efficient fine-tuning reduces GPU memory usage during training and shortens training time, so the model can be trained at lower cost. At the same time, because the basic model parameters are frozen, the general representational capabilities learned on single-modal data are preserved, avoiding the catastrophic forgetting caused by full-parameter fine-tuning. Furthermore, this application employs a momentum contrastive learning framework that maintains a memory queue of key vectors from historical training samples, so that the number of negative samples available when calculating the alignment loss far exceeds the current batch size. This enriches the negative samples without increasing the GPU memory burden, provides a stronger alignment signal, and thus improves the quality and stability of cross-modal alignment. By optimizing the alignment loss, pathological images and gene expression data are mapped into a unified semantic space, achieving effective alignment between the two modalities. This application also discloses a training apparatus for a pre-trained model, an electronic device, a computer-readable storage medium, and a computer program product, which achieve the same technical effects.
[0017] It should be understood that the above general description and the following detailed description are merely exemplary and do not limit this application.
Attached Figure Description
[0018] To more clearly illustrate the technical solutions in the embodiments of this application or the prior art, the drawings used in the description of the embodiments or the prior art are briefly introduced below. Obviously, the drawings described below are only some embodiments of this application; those skilled in the art can obtain other drawings from these drawings without creative effort. The drawings provide a further understanding of this disclosure and constitute a part of the specification; together with the following detailed description they serve to explain this disclosure, but do not limit it. In the drawings:
Figure 1 is a flowchart of a training method for a pre-trained model according to an exemplary embodiment;
Figure 2 is a schematic diagram of the collaborative mechanism of gradient accumulation and queue updating according to an exemplary embodiment;
Figure 3 is a flowchart of another training method for a pre-trained model according to an exemplary embodiment;
Figure 4 is a schematic diagram of a momentum contrastive learning framework and a dual-pyramid fusion structure according to an exemplary embodiment;
Figure 5 is a schematic diagram of multi-scale data construction and dual-view augmentation according to an exemplary embodiment;
Figure 6 is a schematic diagram of a conditional reconstruction head structure according to an exemplary embodiment;
Figure 7 is a structural diagram of a training apparatus for a pre-trained model according to an exemplary embodiment;
Figure 8 is a structural diagram of an electronic device according to an exemplary embodiment.
Detailed Implementation
[0019] The technical solutions of the embodiments of this application will be clearly and completely described below with reference to the accompanying drawings. Obviously, the described embodiments are only some embodiments of this application, and not all embodiments. Based on the embodiments of this application, all other embodiments obtained by those of ordinary skill in the art without creative effort are within the protection scope of this application.
[0020] It should be noted that, in the description of this application, the terms "comprising," "including," or any other variations thereof are intended to cover non-exclusive inclusion, such that a process, method, article, or apparatus that comprises a list of elements includes not only those elements but also other elements not expressly listed, or elements inherent to such a process, method, article, or apparatus. The terms "first," "second," etc., in this application are used to distinguish similar objects and are not used to describe a specific order or sequence.
[0021] To enable those skilled in the art to better understand the present application, the present application will be further described in detail below with reference to the accompanying drawings and specific embodiments.
[0022] This application discloses a training method for a pre-trained model, which achieves effective alignment of pathological images and gene expression data with low training cost while maintaining the capabilities of a single-modality basic model.
[0023] Figure 1 is a flowchart of a training method for a pre-trained model according to an exemplary embodiment. As shown in Figure 1, the method includes:
S101: Obtain the training dataset, wherein the training dataset includes multiple training samples, and each training sample includes a pathological image and corresponding gene expression data. In this step, the training dataset is sourced from a spatial transcriptomics platform such as 10x Visium, and each training sample corresponds to one spatial spot on a tissue section. The pathological images are micrographs of hematoxylin-eosin (H&E) stained tissue sections, resized to 3×112×112 (RGB), with pixel values normalized to the range 0 to 1. The gene expression data consists of gene identifier sequences selected as highly variable genes (HVG) and the corresponding discretized expression values, with a sequence length of at most 1501 (including the CLS token). Each sample may also include pre-computed embedding vectors, such as a 1536-dimensional pathological image embedding extracted by a pre-trained pathological image encoder and a 512-dimensional gene expression embedding extracted by a pre-trained gene expression encoder. Data can be stored in WebDataset format and packaged as tar shards, with each sample stored in NPZ format and accompanied by metadata recording tissue type, species, and other information to support conditional filtering. During data loading, gene sequences of different lengths can be padded to a preset maximum length (e.g., 1501): gene identifiers are padded with a padding token, gene expression values are padded with -2.0, and an attention mask is generated to mark the valid positions. The training and validation sets can be split deterministically based on the hash value of each sample key, for example with 95% of samples in the training set, and both shard-level and sample-level shuffling can be enabled to increase data randomness; the sample-level shuffle buffer size can be 10000.
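The padding and deterministic split described above can be sketched as follows. The padding value -2.0, the maximum length 1501, and the 95% training fraction come from the text; the pad token id `PAD_GENE_ID = 0` and the use of SHA-256 as the hash function are hypothetical, since the text names neither.

```python
import hashlib

PAD_GENE_ID = 0      # hypothetical padding token id (the vocabulary's real pad id is not given)
PAD_VALUE = -2.0     # padding value for expression values, as described in the text
MAX_LEN = 1501       # preset maximum length, including the CLS token


def pad_sample(gene_ids, values, max_len=MAX_LEN):
    """Pad one sample to max_len; return (gene_ids, values, attention_mask)."""
    if len(gene_ids) != len(values):
        raise ValueError("gene_ids and values must have the same length")
    n = len(gene_ids)
    if n > max_len:
        raise ValueError(f"sequence length {n} exceeds max_len={max_len}")
    pad = max_len - n
    ids = list(gene_ids) + [PAD_GENE_ID] * pad
    vals = list(values) + [PAD_VALUE] * pad
    mask = [1] * n + [0] * pad          # 1 = valid position, 0 = padding
    return ids, vals, mask


def is_train_sample(sample_key, train_fraction=0.95):
    """Deterministic split: hash the sample key to a number in [0, 1] and compare it with train_fraction."""
    digest = hashlib.sha256(sample_key.encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < train_fraction
```

A cryptographic hash from `hashlib` is used instead of Python's built-in `hash()` because string hashing is randomized per process, which would make the split change between runs.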
[0024] S102: Construct a pre-trained model, wherein the pre-trained model includes a query branch and a key branch; the query branch includes a first query sub-branch and a second query sub-branch, the first query sub-branch including a pathological image encoder, an image projection head and an image prediction head connected in sequence, and the second query sub-branch including a gene expression encoder, a gene projection head and a gene prediction head connected in sequence; the key branch includes a first key sub-branch and a second key sub-branch, the first key sub-branch including a pathological image encoder and an image projection head connected in sequence, and the second key sub-branch including a gene expression encoder and a gene projection head connected in sequence; and the pathological image encoder and the gene expression encoder include trainable parameter structures. In this step, the pre-trained model is built on a momentum contrastive learning framework, which provides a large number of consistent negative samples by maintaining a slowly updated key branch. The query branch encodes the current training samples and updates its parameters via gradient descent; the key branch tracks the query branch's parameters through an exponential moving average and does not participate in gradient computation, which keeps the key representations stable and consistent. Both the pathological image encoder and the gene expression encoder are pre-trained single-modal base models whose basic parameters are frozen, with only some parameters trainable. The pathological image encoder is a pre-trained model based on the Vision Transformer (ViT) architecture; its input is a 3×112×112 RGB pathological image patch and its output is a 1536-dimensional embedding vector, and all of its basic parameters are frozen after the pre-trained weights are loaded. The gene expression encoder is a pre-trained model based on the Transformer architecture. Its input consists of gene identifier sequences and the corresponding discretized expression values, which are encoded by a gene name embedding layer and a gene value embedding layer respectively, summed, and processed by the Transformer encoder. The output at the CLS token position is taken as a 512-dimensional cell embedding and L2-normalized before being output; all basic parameters are frozen after the pre-trained weights are loaded. The projection head maps the high-dimensional embedding output by the encoder into a low-dimensional contrastive learning space, while the prediction head exists only in the query branch, where it increases the transformation capacity of the query branch and helps prevent representation collapse.
[0025] As a feasible implementation, the trainable parameter structure is a low-rank adaptation (LoRA) module. Low-rank adaptation is a parameter-efficient fine-tuning method: it adds a bypass composed of low-rank matrices in parallel with the frozen pre-trained weights and adapts the pre-trained model by training only this bypass.
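Assuming the common LoRA formulation (the text does not give the formula), the bypass on a frozen weight W_0 and the resulting trainable-parameter count can be written as below; the interpretation of the "scaling factor" as α, so that the bypass is multiplied by α/r, is an assumption, and the 1536 × 1536 layer in the arithmetic is a hypothetical example size.

```latex
h = W_0 x + \frac{\alpha}{r}\, B A x, \qquad
W_0 \in \mathbb{R}^{d_{out} \times d_{in}}\ (\text{frozen}), \quad
B \in \mathbb{R}^{d_{out} \times r}\ (\text{initialized to } 0), \quad
A \in \mathbb{R}^{r \times d_{in}}

\#\text{trainable} = r\,(d_{in} + d_{out}); \quad
r = 16,\ d_{in} = d_{out} = 1536:\ 16 \times 3072 = 49\,152
\ \text{vs.}\ 1536^2 = 2\,359\,296\ (\approx 2.1\%)
```

With rank 16 and scaling factor 16, the multiplier α/r equals 1, and because B starts at zero the adapted layer initially reproduces the frozen pre-trained layer exactly.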
[0026] As a feasible implementation, the target layers of the pathological image encoder into which the low-rank adaptation module is injected include the joint query-key-value projection layer and the output projection layer of the attention layers of the pathological image encoder's Vision Transformer; the target layers of the gene expression encoder into which the low-rank adaptation module is injected include the query, key, value, and output projection layers of the gene expression encoder's Transformer, and the linear layers of its feedforward network.
[0027] In practical implementation, the low-rank adaptation module is deployed in the joint query-key-value projection layer and the feedforward network layers of the Transformer architecture, with a rank of 16, a scaling factor of 16, and a dropout rate of 0.1, so that model behavior can be adjusted effectively with few trainable parameters. In the Vision Transformer, the joint query-key-value projection layer of each attention layer generates the query, key, and value vectors used to compute attention weights, and the output projection layer integrates the outputs of the multiple attention heads. In the Transformer of the gene expression encoder, the joint weight matrix of the standard multi-head attention module is split into independent query, key, and value projection layers so that each projection can be adapted independently, and the feedforward network layers apply a non-linear transformation to the attention output. The low-rank adaptation module is injected into the query, key, value, and output projection layers and the two linear layers of the feedforward network, using the same hyperparameters as for the pathological image encoder. Adapting these core layers guides the model toward features relevant to cross-modal alignment while preserving its generality in other respects.
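A minimal LoRA-wrapped linear layer in plain Python shows how the rank, scaling factor, and dropout interact. It follows the common LoRA formulation (bypass scaled by `alpha / rank`, dropout applied to the bypass input, `B` zero-initialized); the class name `LoRALinear` and the initialization scale of `A` are illustrative assumptions, not details from the text.

```python
import random


class LoRALinear:
    """Frozen weight w0 (d_out x d_in) plus a trainable low-rank bypass B @ A scaled by alpha / rank."""

    def __init__(self, w0, rank=16, alpha=16, dropout=0.1, seed=0):
        self.w0 = w0                                   # frozen: never modified during training
        d_out, d_in = len(w0), len(w0[0])
        self.rng = random.Random(seed)
        # Trainable A (rank x d_in): small random init (scale is a hypothetical choice).
        self.a = [[self.rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        # Trainable B (d_out x rank): zero init, so the bypass contributes nothing at the start.
        self.b = [[0.0] * rank for _ in range(d_out)]
        self.scale = alpha / rank
        self.dropout = dropout

    def forward(self, x, training=False):
        base = [sum(w * xi for w, xi in zip(row, x)) for row in self.w0]
        if training and self.dropout > 0:
            keep = 1.0 - self.dropout                  # inverted dropout on the bypass input only
            x = [xi / keep if self.rng.random() < keep else 0.0 for xi in x]
        ax = [sum(a * xi for a, xi in zip(row, x)) for row in self.a]
        bax = [sum(b * v for b, v in zip(row, ax)) for row in self.b]
        return [h + self.scale * d for h, d in zip(base, bax)]

    def trainable_params(self):
        """Number of entries in A and B, i.e. rank * (d_in + d_out)."""
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])


layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]], rank=1, alpha=1, dropout=0.0)
out = layer.forward([1.0, 1.0])   # equals the frozen layer's output because B is zero
```

Only `a` and `b` would receive gradient updates; `w0` stays untouched, which is what preserves the pre-trained encoder's general representations.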
[0028] In another embodiment, the low-rank adaptation module is replaced with another parameter-efficient fine-tuning method. For example, it can be replaced with learnable prompts (Prompt Tuning), which prepend learnable prompt vectors to the encoder input, or with adapter layers, which insert a bottleneck-structured adaptation layer after the feedforward network of each Transformer layer.
[0029] As a feasible implementation, the projection head consists of a linear layer, a ReLU activation layer, and a linear layer connected in sequence. The pathological image projection head maps a 1536-dimensional input to a 128-dimensional output; the gene expression projection head maps a 512-dimensional input to a 128-dimensional output. The prediction head exists only in the query branch and consists of a linear layer (128-dimensional input to 256-dimensional output), a batch normalization layer (BatchNorm1d), a ReLU activation layer, and a linear layer (256-dimensional input to 128-dimensional output) connected in sequence. Because batch normalization couples the samples in a batch through shared batch statistics, it provides an implicit negative-sample effect that helps prevent representation collapse.
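The two alternatives can be sketched in a few lines. The bottleneck adapter below uses the usual residual form `x + W_up · ReLU(W_down · x)` and prompt tuning is shown as prepending prompt vectors to the token sequence; the function names, the ReLU choice, and the residual connection are illustrative assumptions, since the text only names the techniques.

```python
def adapter(x, w_down, w_up):
    """Bottleneck adapter (hypothetical form): x + W_up @ ReLU(W_down @ x)."""
    h = [max(0.0, sum(w * v for w, v in zip(row, x))) for row in w_down]   # down-projection + ReLU
    delta = [sum(w * v for w, v in zip(row, h)) for row in w_up]           # up-projection
    return [a + b for a, b in zip(x, delta)]                               # residual connection


def prepend_prompts(prompts, tokens):
    """Prompt tuning: learnable prompt vectors are placed before the input token vectors."""
    return list(prompts) + list(tokens)
```

With a zero-initialized up-projection the adapter starts as the identity, so, as with LoRA, training begins from the frozen model's behavior.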
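The head shapes above can be made concrete with a plain-Python sketch. The input and output widths (1536 to 128, 512 to 128, and 128 to 256 to 128) come from the text; the projection heads' hidden width `HIDDEN = 256` and the weight initialization are hypothetical, and the batch normalization shown is training-mode BatchNorm1d at initialization (scale 1, shift 0, biased batch variance).

```python
import random

HIDDEN = 256  # hypothetical hidden width of the projection heads (not specified in the text)


def init_linear(d_in, d_out, rng):
    """Random weight matrix (d_out x d_in) and zero bias for a fully connected layer."""
    scale = d_in ** -0.5
    return [[rng.gauss(0.0, scale) for _ in range(d_in)] for _ in range(d_out)], [0.0] * d_out


def linear(batch, weight, bias):
    return [[sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)] for x in batch]


def relu(batch):
    return [[max(0.0, v) for v in x] for x in batch]


def batch_norm(batch, eps=1e-5):
    """Normalize each feature with the mean and (biased) variance computed over the batch."""
    n, dims = len(batch), len(batch[0])
    means = [sum(x[d] for x in batch) / n for d in range(dims)]
    variances = [sum((x[d] - means[d]) ** 2 for x in batch) / n for d in range(dims)]
    return [[(x[d] - means[d]) / (variances[d] + eps) ** 0.5 for d in range(dims)] for x in batch]


class MLPHead:
    """Linear -> [BatchNorm1d] -> ReLU -> Linear."""

    def __init__(self, d_in, d_hidden, d_out, use_bn, seed=0):
        rng = random.Random(seed)
        self.fc1 = init_linear(d_in, d_hidden, rng)
        self.fc2 = init_linear(d_hidden, d_out, rng)
        self.use_bn = use_bn

    def __call__(self, batch):
        h = linear(batch, *self.fc1)
        if self.use_bn:
            h = batch_norm(h)        # couples samples in the batch through shared statistics
        return linear(relu(h), *self.fc2)


image_projection = MLPHead(1536, HIDDEN, 128, use_bn=False)
gene_projection = MLPHead(512, HIDDEN, 128, use_bn=False, seed=1)
prediction_head = MLPHead(128, 256, 128, use_bn=True, seed=2)   # query branch only
```

Because the normalized value of each sample depends on the other samples in the batch, outputs cannot all collapse to one constant vector without the batch statistics reflecting it, which is the intuition behind the "implicit negative samples" remark.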
[0030] S103: Input the training samples into the pre-trained model, and generate the query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch. In this step, the training samples are fed through the query branch and the key branch for forward propagation. Specifically, in the query branch the pathological image is first fed into the pathological image encoder, which, through the low-rank adaptation module and the encoder's base network, outputs a high-dimensional pathological image embedding. This embedding then passes through the image projection head and the image prediction head in sequence and is L2-normalized, yielding a 128-dimensional pathological image query vector. Similarly, the gene expression data is processed by the gene expression encoder (including its low-rank adaptation module) and then by the gene projection head and the gene prediction head to obtain a normalized gene expression query vector. The key branch follows the same flow but has no prediction head: after the pathological image and the gene expression data pass through their respective encoders and projection heads, the normalized pathological image key vector and gene expression key vector are obtained directly.
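The asymmetry between the two branches (prediction head plus L2 normalization on the query side, projection plus L2 normalization on the key side) can be expressed as two small composition functions. The toy `encoder`, `projection`, and `prediction` callables in the usage are hypothetical placeholders for the real modules.

```python
import math


def l2_normalize(v, eps=1e-12):
    """Scale a vector to unit Euclidean length (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def query_vector(x, encoder, projection, prediction):
    """Query branch: encoder -> projection head -> prediction head -> L2 normalization."""
    return l2_normalize(prediction(projection(encoder(x))))


def key_vector(x, encoder, projection):
    """Key branch: encoder -> projection head -> L2 normalization (no prediction head)."""
    return l2_normalize(projection(encoder(x)))


# Usage with hypothetical toy modules.
encoder = lambda x: [2.0 * v for v in x]
projection = lambda h: h[:2]
prediction = lambda z: [z[1], z[0]]
q = query_vector([3.0, 4.0, 5.0], encoder, projection, prediction)
k = key_vector([3.0, 4.0, 5.0], encoder, projection)
```

Normalizing both outputs to unit length makes the later dot products cosine similarities, so the temperature alone controls how sharp the contrastive softmax is.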
[0031] S104: Calculate the alignment loss based on the similarity between the query vectors and the key vector set, where the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in the memory queue. The alignment loss is calculated from a first alignment loss and a second alignment loss: the first is calculated from the similarity between the pathological image query vector and the gene expression key vectors, and the second from the similarity between the gene expression query vector and the pathological image key vectors.
In this step, the goal of contrastive learning is to pull together the representations of different modalities of the same training sample (positive pairs) and push apart the modality representations of different samples (negative pairs). As a feasible implementation, the alignment loss can be a contrastive loss. Specifically, the dot-product similarity between each pathological image query vector and all gene expression key vectors is computed and divided by a temperature coefficient (e.g., τ = 0.07) to obtain logits_p2g, the logits matrix of the first contrastive loss. The gene expression key vectors here comprise those of all samples in the current batch together with the historical gene expression key vectors from the memory queue. Similarly, the similarity between each gene expression query vector and all pathological image key vectors yields logits_g2p, the logits matrix of the second contrastive loss. The memory queue is a first-in, first-out queue that stores the key vectors of historical training samples, greatly expanding the number of negative samples without increasing the batch size and making the contrastive signal more comprehensive. The labels are indices within the current batch and indicate the position of each positive pair in the logits matrix. The contrastive loss can be taken as the average of the two cross-entropy losses: L_nce = (CrossEntropy(logits_p2g, labels) + CrossEntropy(logits_g2p, labels)) / 2, where L_nce is the contrastive loss, CrossEntropy(logits_p2g, labels) is the first contrastive loss, CrossEntropy(logits_g2p, labels) is the second contrastive loss, and labels are the positive-sample labels in contrastive learning.
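The bidirectional loss of [0031] can be sketched with plain Python lists as below, assuming the inputs are already L2-normalized. Candidate keys for each query are the current-batch keys followed by the queue keys, so the positive for query i sits at column i and the labels equal batch indices, as the text describes. Function names are illustrative; a framework implementation would compute the same logits with a matrix product and a built-in cross-entropy.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cross_entropy(logits, target):
    # Numerically stable -log(softmax(logits)[target]).
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target]

def directional_nce(queries, batch_keys, queue_keys, tau):
    # Candidates: current-batch keys first, then historical queue keys, so the
    # positive key for query i is at column i (labels = batch indices).
    candidates = batch_keys + queue_keys
    losses = [cross_entropy([dot(q, k) / tau for k in candidates], i)
              for i, q in enumerate(queries)]
    return sum(losses) / len(losses)

def symmetric_nce(q_img, q_gene, k_img, k_gene, queue_img, queue_gene, tau=0.07):
    loss_p2g = directional_nce(q_img, k_gene, queue_gene, tau)  # image queries vs. gene keys
    loss_g2p = directional_nce(q_gene, k_img, queue_img, tau)   # gene queries vs. image keys
    return (loss_p2g + loss_g2p) / 2

eye = [[1.0, 0.0], [0.0, 1.0]]
queue = [[-1.0, 0.0]]
aligned = symmetric_nce(eye, eye, eye, eye, queue, queue)               # positives match
swapped = symmetric_nce(eye, eye, eye[::-1], eye[::-1], queue, queue)   # positives mismatched
```

With τ = 0.07 the loss is close to zero when each query coincides with its positive key and large when the positives are mismatched.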
[0032] In another embodiment, the contrastive loss is replaced with another form of alignment loss. For example, a knowledge distillation loss may be used, which performs soft-label distillation with the output distribution of a teacher model as the target; or a ranking loss such as the triplet loss may be used, which optimizes the representations so that the distance within a positive pair is smaller than the distance to a negative sample.
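For the ranking-loss alternative of [0032], a standard triplet loss on one (anchor, positive, negative) triple looks like the sketch below. The Euclidean distance and the margin of 0.2 are assumptions, since the text names neither. In the cross-modal setting the anchor could be a pathological image query vector, the positive its paired gene expression key vector, and the negative a gene expression key vector from another sample.

```python
import math

def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=0.2):
    # Zero once the positive is closer to the anchor than the negative by at least `margin`.
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)

easy = triplet_loss([0.0, 0.0], [0.1, 0.0], [1.0, 0.0])  # positive already much closer
hard = triplet_loss([0.0, 0.0], [1.0, 0.0], [0.5, 0.0])  # negative closer than positive
```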
[0033] In another embodiment, a cross-modal consistency regularization term is introduced, adding a representation smoothness constraint to the contrastive loss. For example, a consistency loss (such as a mean squared error loss or a cosine similarity loss) is applied to the encoder outputs for different augmented views of the same sample, improving robustness to staining noise and to the sparsity of gene expression data.
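The consistency term of [0033] could take either of the two forms the text mentions, sketched below for a pair of encoder outputs from two augmented views of the same sample. The text does not say how this term is weighted against the contrastive loss, so no weight is shown.

```python
import math

def cosine_consistency_loss(view_a, view_b):
    # 1 - cosine similarity; 0 when the two outputs point in the same direction.
    dot = sum(x * y for x, y in zip(view_a, view_b))
    norm_a = math.sqrt(sum(x * x for x in view_a))
    norm_b = math.sqrt(sum(y * y for y in view_b))
    return 1.0 - dot / (norm_a * norm_b)

def mse_consistency_loss(view_a, view_b):
    # Mean squared difference between the two outputs.
    return sum((x - y) ** 2 for x, y in zip(view_a, view_b)) / len(view_a)
```

The cosine form ignores differences in magnitude between the two views, whereas the MSE form penalizes them.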
[0034] S105: Based on the alignment loss, update the trainable parameters in the query branch through backpropagation, update the trainable parameters in the key branch through exponential moving average, and write the key vector corresponding to the current training sample into the memory queue.
[0035] In this step, the gradient is calculated using the backpropagation algorithm based on the calculated contrastive loss. Only the trainable parameters of the low-rank adaptation modules in the query branch are updated, while the encoder's base parameters remain frozen. The parameters of the key branch are not updated directly via gradient descent, but rather using an exponential moving average. That is, the new parameters of the key branch are obtained by a weighted average of the current parameters of the key branch and the updated parameters of the query branch, making the parameter updates of the key branch smoother and more stable. Simultaneously, the key vectors corresponding to the current training samples (including pathological image key vectors and gene expression key vectors) are written into a memory queue for use in subsequent training steps to calculate the contrastive loss. If the memory queue is full, the oldest key vector is evicted according to the first-in, first-out principle, thus ensuring that the queue stores recent and representative negative samples.
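The first-in, first-out behavior of the memory queue can be sketched with Python's `collections.deque`, whose `maxlen` argument discards the oldest items when new ones are appended to a full deque. The class name and the `initial` argument are illustrative; [0039] describes a fixed-size store pre-filled with normalized random vectors, which `initial` stands in for.

```python
from collections import deque

class MemoryQueue:
    """Fixed-capacity FIFO store of historical key vectors (illustrative sketch)."""

    def __init__(self, capacity, initial=()):
        # A deque with maxlen drops items from the left (oldest first) when full.
        self._keys = deque(initial, maxlen=capacity)

    def enqueue(self, keys):
        # Append the newest key vectors; the oldest are evicted if capacity is exceeded.
        self._keys.extend(keys)

    def snapshot(self):
        # Historical keys used as extra negatives in the next loss computation.
        return list(self._keys)

queue = MemoryQueue(capacity=3, initial=[[0.0], [1.0]])
queue.enqueue([[2.0], [3.0]])
```

After the second write the oldest vector has been evicted, so the queue holds the three most recent keys.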
[0036] As a feasible implementation, updating the trainable parameters in the query branch through backpropagation based on the alignment loss, updating the trainable parameters in the key branch through exponential moving average, and writing the key vectors corresponding to the current training samples into the memory queue includes: during each forward pass, storing the key vectors produced by that pass in a buffer; after every preset number of forward passes, updating the trainable parameters in the query branch through backpropagation using the gradients accumulated over those forward passes; updating the trainable parameters in the key branch through exponential moving average based on the updated query-branch parameters, where the update applied to each key-branch parameter is the difference between the current values of the corresponding query-branch and key-branch parameters multiplied by a weight (1-m, for momentum coefficient m); and writing the key vectors from the preset number of forward passes stored in the buffer into the memory queue, then clearing the buffer.
[0037] This implementation provides a mechanism that coordinates gradient accumulation with queue updates. Because of hardware memory limits, the batch size usually cannot be set very large. As shown in Figure 2, the gradient accumulation strategy lets the model accumulate gradients over several mini-batch forward passes and then perform a single parameter update of the query branch at the end of the accumulation period, thereby simulating a larger batch size. During this process, the key vectors generated by each mini-batch's forward pass are not pushed into the memory queue immediately but are first stored in a buffer. When the accumulation period ends and the parameter update is complete, the key branch parameters are updated by exponential moving average, all buffered mini-batch key vectors are pushed into the memory queue at once, and the buffer is cleared. The buffer is also cleared at the end of each training epoch to prevent carry-over across epochs. This design keeps the key vectors in the memory queue consistent with the data distribution behind the gradient used for the current parameter update, avoiding the inconsistent negative samples that asynchronous parameter and queue updates would cause, and thus improves training stability and convergence efficiency.
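The coordination between gradient accumulation, the EMA update, and the key buffer in [0036] and [0037] can be traced with scalar stand-ins, as in the toy sketch below. Everything numeric here is hypothetical: each mini-batch is reduced to a precomputed gradient and a list of key labels, a single float represents each branch's trainable parameters, and plain SGD stands in for AdamW. Resetting the partial gradient at epoch end is also an assumption, made so that the discarded keys and their gradients are dropped together.

```python
from collections import deque

def train_epoch(minibatches, state, queue, accum_steps=2, lr=0.1, m=0.999):
    """Toy trace of the accumulate / update / flush cycle.

    minibatches: iterable of (gradient, keys) pairs, one per forward pass.
    state: dict with scalar stand-ins 'theta_q', 'theta_k' and 'grad_sum'.
    queue: a deque(maxlen=...) acting as the memory queue.
    """
    buffer = []
    for step, (grad, keys) in enumerate(minibatches, start=1):
        state["grad_sum"] += grad      # accumulate this mini-batch's gradient
        buffer.extend(keys)            # hold keys back instead of enqueuing them now
        if step % accum_steps == 0:
            # One query-branch update with the averaged accumulated gradient (SGD here).
            state["theta_q"] -= lr * state["grad_sum"] / accum_steps
            state["grad_sum"] = 0.0
            # EMA update of the key branch from the freshly updated query branch.
            state["theta_k"] = m * state["theta_k"] + (1 - m) * state["theta_q"]
            queue.extend(buffer)       # push all buffered keys at once
            buffer.clear()
    buffer.clear()                     # epoch end: no carry-over of buffered keys
    state["grad_sum"] = 0.0            # assumption: drop the partial gradient as well
    return state

queue = deque(maxlen=8)
state = {"theta_q": 0.0, "theta_k": 0.0, "grad_sum": 0.0}
train_epoch([(1.0, ["k1"]), (1.0, ["k2"]), (1.0, ["k3"])], state, queue)
```

With two accumulation steps, the keys of the first two mini-batches reach the queue together with the single update they contributed to, while the third mini-batch's key is discarded at the end of the epoch.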
[0038] In practice, the AdamW optimizer can be used with an initial learning rate of 1×10⁻⁴, a weight decay of 1×10⁻⁴, and β = (0.9, 0.999). The learning rate follows a cosine annealing schedule with a minimum learning rate of 1×10⁻⁸. The maximum number of training epochs is 100. The number of gradient accumulation steps is 1 and can be adjusted according to GPU memory usage. The gradient clipping threshold is 10.0, and training uses BF16 precision. Gradient checkpointing is enabled for both the pathological image encoder and the gene expression encoder to reduce GPU memory usage.
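The cosine annealing schedule of [0038] has a simple closed form, sketched below with the stated values (initial rate 1×10⁻⁴, minimum 1×10⁻⁸, 100 epochs) and assuming the rate is updated once per epoch. In PyTorch the same setup would typically come from `torch.optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4)` together with `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-8)` and `torch.nn.utils.clip_grad_norm_(params, 10.0)`. That mapping is an assumption: the text names no framework and does not say whether the clipping threshold applies to the gradient norm.

```python
import math

def cosine_annealing_lr(epoch, max_epochs=100, lr_max=1e-4, lr_min=1e-8):
    # lr_max at epoch 0, decaying along a half cosine to lr_min at epoch max_epochs.
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / max_epochs))

schedule = [cosine_annealing_lr(e) for e in range(101)]
```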
[0039] Furthermore, no parameters of the key branch participate in gradient computation; they are updated by exponential moving average (EMA): θ_k = m·θ_k + (1-m)·θ_q, where θ_k denotes the trainable parameters in the key branch, θ_q denotes the trainable parameters in the query branch, and the momentum coefficient m = 0.999. A memory queue of size 32768×128 is initialized for each of the pathological image and gene expression modalities, filled with L2-normalized random vectors, and updated with a first-in, first-out strategy.
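The EMA rule and the queue initialization of [0039] can be sketched element-wise as below. Parameters are flattened into plain lists for illustration, and drawing the random vectors from a Gaussian is an assumption, as the text only says the queue is initialized with random vectors and L2-normalized. The usage lines build tiny queues in place of 32768×128.

```python
import math
import random

def ema_update(theta_k, theta_q, m=0.999):
    # theta_k <- m * theta_k + (1 - m) * theta_q, element-wise; no gradients involved.
    return [m * k + (1 - m) * q for k, q in zip(theta_k, theta_q)]

def init_queue(size, dim, rng):
    # Random Gaussian vectors, each scaled to unit L2 norm.
    queue = []
    for _ in range(size):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in v))
        queue.append([x / norm for x in v])
    return queue

theta_k = ema_update([1.0, 0.0], [0.0, 1.0])
image_queue = init_queue(4, 8, random.Random(0))  # 32768 x 128 in the text
gene_queue = init_queue(4, 8, random.Random(1))
```

With m = 0.999 each update moves the key parameters only 0.1% of the way toward the query parameters, which is what keeps the keys in the queue mutually consistent.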
[0040] In one embodiment, an early stopping strategy is employed to monitor the validation set loss, with a patience value of 50 training epochs. During the validation phase, the total loss calculation does not use weighting; that is, the validation loss is a direct sum of the contrastive loss and the reconstruction loss, used to evaluate the model's generalization ability.
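A minimal early-stopping tracker matching [0040] is sketched below. Treating any strict decrease in validation loss as an improvement (no minimum delta) is an assumption; the validation loss itself is the unweighted sum of the two losses, as stated.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=50):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        # Returns True when training should stop.
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

def validation_loss(l_nce, l_recon):
    # Unweighted sum used only for monitoring generalization.
    return l_nce + l_recon

stopper = EarlyStopping(patience=2)
decisions = [stopper.step(validation_loss(*pair)) for pair in [(0.6, 0.4), (1.0, 0.5), (0.7, 0.5)]]
```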
[0041] S106: Input the pathological image to be analyzed into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, where the pathological image representation is applied to downstream tasks including any one of cross-modal retrieval, tissue classification, and gene expression prediction.
[0042] In practice, after training, the parameters of the trainable parameter structure are merged into the base model weights to obtain the aligned pathological image encoder. During inference, the pathological image to be analyzed is input into the aligned pathological image encoder to obtain a 1536-dimensional representation vector aligned with the gene expression semantic space. This representation vector can be used for the following downstream tasks: cross-modal retrieval (given a pathological image, retrieving the semantically most similar gene expression profile from a gene expression embedding library); tissue classification (training a linear classifier on the aligned representations to classify tissue types); and gene expression prediction (predicting gene expression values from a pathological image with the conditional reconstruction head).
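The weight merge in [0042] can be illustrated for a single linear layer. The sketch assumes the common LoRA parameterization, in which the frozen weight W is augmented by a low-rank product B·A scaled by alpha / r (alpha being the scaling factor and r the rank listed in [0060]); the text does not spell out the parameterization, so treat the formula as an assumption. After merging, one matrix reproduces the output of the base layer plus the adapter, so inference needs no extra modules.

```python
def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def merge_lora(w, lora_a, lora_b, alpha):
    """Return W + (alpha / r) * B @ A.

    w: d_out x d_in frozen weight; lora_a: r x d_in; lora_b: d_out x r.
    """
    r = len(lora_a)
    delta = matmul(lora_b, lora_a)
    scale = alpha / r
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]

def apply(w, x):
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

w = [[1.0, 0.0], [0.0, 1.0]]
lora_a, lora_b = [[1.0, 1.0]], [[1.0], [0.0]]   # rank r = 1
merged = merge_lora(w, lora_a, lora_b, alpha=2.0)
x = [1.0, 2.0]
# Base output plus scaled adapter output, computed without merging.
unmerged_out = [b + 2.0 * d for b, d in zip(apply(w, x), apply(lora_b, apply(lora_a, x)))]
```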
[0043] The training method for the pre-trained model provided in this application significantly reduces the total number of parameters to be trained by freezing the basic parameters of the pre-trained encoder and updating only the injected trainable parameter structure. This efficient parameter fine-tuning method significantly reduces the GPU memory usage during training and shortens the training time, thereby achieving the technical effect of training the model at a lower cost. Simultaneously, because the basic model parameters are frozen, the general representational capabilities learned on single-modal data are fully preserved, effectively avoiding the catastrophic forgetting problem caused by full parameter fine-tuning. Furthermore, this application employs a momentum contrastive learning framework, maintaining a memory queue to store the key vectors of historical training samples, enabling the acquisition of negative samples far exceeding the current batch size when calculating the alignment loss. This enriches the number of negative samples without increasing the GPU memory burden, making the alignment signal more sufficient and thus improving the quality and stability of cross-modal alignment. Through the optimization of the alignment loss, pathological images and gene expression data can be mapped to a unified semantic space, achieving effective alignment between the two modalities.
[0044] This application discloses a training method for a pre-trained model. Compared with the previous embodiment, this embodiment further explains and optimizes the technical solution. Referring to Figure 2, which is a flowchart of another training method for a pre-trained model according to an exemplary embodiment, the method includes:
S201: Obtain the training dataset, where the training dataset includes multiple training samples and each training sample includes a pathological image and the corresponding gene expression data.
S202: Extract data at multiple spatial scales from the training samples, where the spatial scales include a single-point scale, a first neighborhood scale, and a second neighborhood scale, the second neighborhood scale being larger than the first.
In this step, given the inherent multi-scale spatial structure of spatial transcriptomics data, multi-scale data is constructed for each spatial locus. The single-point scale consists of the pathological image patch and the gene expression data of the locus itself. The first neighborhood scale aggregates data from the current locus and several of its nearest neighbors (e.g., 7 neighboring loci), representing the context of a local neighborhood. The second neighborhood scale aggregates a larger range of neighboring loci (e.g., 19 neighboring loci), representing broader tissue microenvironment information. This construction captures biological features at spatial granularities ranging from fine to coarse. Each scale contains its own pathological image patches, gene identifier sequences, and gene expression values, providing the basis for the subsequent multi-scale fusion.
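One way to select the spots for the three scales in [0044] is a nearest-neighbor query on spot coordinates, sketched below. Two assumptions are made: the counts 7 and 19 include the center spot (on a hexagonal spot grid these are exactly the center plus its first one and two rings), and neighbors are ranked by Euclidean distance with ties broken by index. The text does not specify how the selected spots' data are then aggregated, so the sketch only returns indices.

```python
import math

def multiscale_neighborhoods(coords, index, scales=(1, 7, 19)):
    """For each scale k, return the indices of the k spots nearest to spot `index`.

    The spot itself (distance 0) is always first; ties are broken by index.
    """
    cx, cy = coords[index]
    order = sorted(range(len(coords)),
                   key=lambda j: (math.hypot(coords[j][0] - cx, coords[j][1] - cy), j))
    return {k: order[:k] for k in scales}

# Hexagonal spot grid in axial coordinates (q, r), converted to Cartesian positions.
axial = [(q, r) for q in range(-3, 4) for r in range(-3, 4) if abs(q + r) <= 3]
coords = [(q + r / 2, r * math.sqrt(3) / 2) for q, r in axial]
center = axial.index((0, 0))
neighborhoods = multiscale_neighborhoods(coords, center)
```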
[0045] S203: Construct a pre-trained model. The pre-trained model includes a query branch and a key branch. The query branch includes a first query sub-branch and a second query sub-branch. The first query sub-branch includes, connected in sequence, a pathological image encoder carrying a low-rank adaptation module, a first multi-scale fusion module, an image projection head, and an image prediction head. The second query sub-branch includes, connected in sequence, a gene expression encoder carrying a low-rank adaptation module, a second multi-scale fusion module, a gene projection head, and a gene prediction head. The key branch includes a first key sub-branch and a second key sub-branch. The first key sub-branch includes, connected in sequence, a pathological image encoder carrying a low-rank adaptation module, a third multi-scale fusion module, and an image projection head. The second key sub-branch includes, connected in sequence, a gene expression encoder carrying a low-rank adaptation module, a fourth multi-scale fusion module, and a gene projection head. The low-rank adaptation module is a trainable parameter matrix. The pathological image encoder encodes the pathological image at each spatial scale to obtain pathological image embeddings at multiple scales, and the gene expression encoder encodes the gene expression data at each spatial scale to obtain gene expression embeddings at multiple scales. The first and third multi-scale fusion modules fuse the multi-scale pathological image embeddings into a fused pathological image embedding, and the second and fourth multi-scale fusion modules fuse the multi-scale gene expression embeddings into a fused gene expression embedding.
In this step, as shown in Figure 4, the pre-trained model adds multi-scale fusion modules to the model of the previous embodiment. Each encoder (the pathological image encoder and the gene expression encoder) processes the input data at every scale and generates an embedding vector for each scale, and the multi-scale fusion module integrates these per-scale embeddings into a single fused embedding. For example, for a given spatial locus, the pathological image encoder outputs pathological image embeddings at the single-point scale, the first neighborhood scale, and the second neighborhood scale, and the first multi-scale fusion module fuses them into one fused pathological image embedding. The same operation is performed on the gene expression side. This structure lets the model adaptively learn how much the features at each scale contribute to the final alignment task, making full use of multi-scale spatial structure information.
[0046] The multi-scale fusion module contains learnable scalar weight parameters equal to the number of spatial scales. For example, three weight parameters correspond to the single-point scale, the first neighborhood scale, and the second neighborhood scale, respectively. These are initialized to the reciprocal of the number of scales, i.e., 1 / 3, and after softmax normalization, the embedding vectors at each scale are weighted and summed. It should be noted that the query branch and the key branch each maintain an independent set of multi-scale fusion modules, allowing the query and key ends to learn different fusion strategies.
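The learnable weighted sum of [0046] is sketched below with plain lists; gradient updates to the weights are omitted. Initializing all logits to the same value (1/3) makes the softmax weights exactly uniform at the start, so each fusion module begins as a plain average of the scale embeddings and departs from it only as training updates its logits. Separate instances model the independent query-side and key-side modules.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

class MultiScaleFusion:
    """One learnable scalar per scale; softmax-normalized weighted sum of embeddings."""

    def __init__(self, num_scales=3):
        self.logits = [1.0 / num_scales] * num_scales  # initialized to 1 / num_scales

    def __call__(self, scale_embeddings):
        weights = softmax(self.logits)
        dim = len(scale_embeddings[0])
        return [sum(w * emb[j] for w, emb in zip(weights, scale_embeddings))
                for j in range(dim)]

# Independent modules: the query and key branches learn their own weights.
query_image_fusion = MultiScaleFusion()
key_image_fusion = MultiScaleFusion()
fused = query_image_fusion([[3.0, 0.0], [0.0, 3.0], [0.0, 0.0]])
```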
[0047] In another embodiment, the learnable weighted summation of the multi-scale fusion module is replaced with an attention mechanism fusion. For example, a cross-attention module is used, with embeddings at one scale as queries and embeddings at other scales as keys and values, and multi-scale information is adaptively fused through attention weights.
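The attention-based alternative in [0047] can be sketched as single-head scaled dot-product attention. The simplifications are assumptions: there are no learned query, key, or value projections, and the single-point embedding serves as the query while the two neighborhood embeddings serve as keys and values. A practical module would add learned projections and possibly multiple heads.

```python
import math

def cross_attention_fuse(query_emb, other_embs):
    """Attend from one scale's embedding (query) over the other scales' embeddings
    (used as both keys and values) and return the attention-weighted sum."""
    d = len(query_emb)
    scores = [sum(q * k for q, k in zip(query_emb, emb)) / math.sqrt(d) for emb in other_embs]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    attn = [e / total for e in exps]
    return [sum(a * emb[j] for a, emb in zip(attn, other_embs)) for j in range(d)]

single_point = [10.0, 0.0]
neighborhood_embs = [[1.0, 0.0], [0.0, 1.0]]   # first and second neighborhood scales
fused = cross_attention_fuse(single_point, neighborhood_embs)
```

The neighborhood embedding most similar to the single-point embedding receives the larger attention weight.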
[0048] S204: Augment the pathological image with a first augmentation strategy to obtain a query pathological image, and augment it with a second augmentation strategy to obtain a key pathological image, where the augmentation intensity of the first augmentation strategy is higher than that of the second.
In this step, as shown in Figure 5, an asymmetric data augmentation strategy is applied to the input images of the query and key branches to improve the model's robustness to staining differences and geometric variations in pathological images. The query branch uses a strong augmentation strategy (the first augmentation strategy), which may include color space jitter and geometric transformations. In one implementation of color space jitter, the image is converted from RGB space to optical density space and then to the Hematoxylin-Eosin-DAB (HED) color space through a color deconvolution matrix; Gaussian random perturbations are applied to the hematoxylin and eosin channels to simulate staining differences, and the image is finally transformed back to RGB space. The geometric transformations may include random horizontal flipping (probability 0.5), random vertical flipping (probability 0.5), random rotation by an angle selected from 0°, 90°, 180°, and 270°, and random resized cropping (scale range 0.8 to 1.0). The key branch uses only a weak augmentation strategy (the second augmentation strategy), consisting of random horizontal flipping (probability 0.5) and random vertical flipping (probability 0.5), without rotation, scaling, or color jitter. This asymmetric strategy forces the model to learn representations that are invariant between the strongly augmented view in the query branch and the weakly augmented view in the key branch, thereby improving generalization to real-world image variations.
Images of the same sample at multiple spatial scales share the same augmentation parameters to ensure consistency of multi-scale data across spatial transformations.
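The HED color jitter of [0048] can be sketched for a single RGB pixel as below. The stain matrix is the Ruifrok and Johnston H&E-DAB matrix that scikit-image also uses as `rgb_from_hed`. The multiplicative form of the Gaussian perturbation (scaling the H and E concentrations by 1 + δ with δ drawn from N(0, σ²)) is an assumption, since the text only says that Gaussian perturbations are applied. Passing the sampled δ values in explicitly reflects the requirement that every scale of a sample share the same augmentation parameters.

```python
import math
import random

# Ruifrok-Johnston stain matrix: each row is the optical-density vector of one
# stain (hematoxylin, eosin, DAB), as in scikit-image's rgb_from_hed.
RGB_FROM_HED = [[0.65, 0.70, 0.29],
                [0.07, 0.99, 0.11],
                [0.27, 0.57, 0.78]]

def inv3(m):
    # Inverse of a 3x3 matrix via the adjugate.
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    adj = [[e * i - f * h, c * h - b * i, b * f - c * e],
           [f * g - d * i, a * i - c * g, c * d - a * f],
           [d * h - e * g, b * g - a * h, a * e - b * d]]
    return [[x / det for x in row] for row in adj]

HED_FROM_RGB = inv3(RGB_FROM_HED)  # color deconvolution matrix

def row_times(v, m):
    # Row vector (length 3) times a 3x3 matrix.
    return [sum(v[i] * m[i][j] for i in range(3)) for j in range(3)]

def hed_jitter(pixel, d_h, d_e, eps=1e-6):
    """Perturb one RGB pixel with channel values in [0, 1] in HED space."""
    od = [-math.log(max(c, eps)) for c in pixel]   # RGB -> optical density
    h, e, dab = row_times(od, HED_FROM_RGB)        # OD -> stain concentrations
    hed = [h * (1 + d_h), e * (1 + d_e), dab]      # perturb H and E only
    od = row_times(hed, RGB_FROM_HED)              # stains -> OD
    return [min(1.0, math.exp(-x)) for x in od]    # OD -> RGB, clipped at 1

# Sample the perturbation once per training sample, then reuse it for the
# patches of every spatial scale (shared augmentation parameters).
rng = random.Random(0)
sigma = 0.05                                       # within the 0.01-0.1 range of [0060]
d_h, d_e = rng.gauss(0.0, sigma), rng.gauss(0.0, sigma)
patches = {"single": [[0.8, 0.5, 0.7]], "first": [[0.6, 0.4, 0.6]], "second": [[0.9, 0.8, 0.85]]}
augmented = {scale: [hed_jitter(p, d_h, d_e) for p in pixels] for scale, pixels in patches.items()}
```

The geometric transforms (flips, 90-degree rotations, resized cropping) would be sampled once per sample in the same way and applied to every scale.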
[0049] S205: Input the query pathological image into the first query sub-branch to generate a pathological image query vector, and input the gene expression data into the second query sub-branch to generate a gene expression query vector; input the key pathological image into the first key sub-branch to generate a pathological image key vector, and input the gene expression data into the second key sub-branch to generate a gene expression key vector.
In this step, the augmented multi-scale data is fed into the pre-trained model. In the query branch, the strongly augmented multi-scale pathological images are encoded by the pathological image encoder to obtain a pathological image embedding at each scale, which the first multi-scale fusion module fuses into a single 1536-dimensional pathological image embedding. This embedding passes through the image projection head and the image prediction head and is L2-normalized, yielding a 128-dimensional pathological image query vector. The gene expression data is fed, without augmentation, into the second query sub-branch, where it is encoded by the gene expression encoder, fused by the second multi-scale fusion module, passed through the gene projection head and the gene prediction head, and L2-normalized to obtain a 128-dimensional gene expression query vector. In the key branch, the weakly augmented multi-scale pathological images are encoded by the pathological image encoder, fused by the third multi-scale fusion module, passed through the image projection head, and L2-normalized to obtain a 128-dimensional pathological image key vector; because the key branch has no prediction head, its processing is simpler. The gene expression key vector is obtained in the same way, by passing the gene expression data through the gene expression encoder, the fourth multi-scale fusion module, and the gene projection head, followed by L2 normalization.
[0050] S206: Calculate the contrastive loss based on the similarity between the query vectors and the key vector set, where the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in the memory queue. The contrastive loss is calculated from a first contrastive loss and a second contrastive loss: the first is calculated from the similarity between the pathological image query vector and the gene expression key vectors, and the second from the similarity between the gene expression query vector and the pathological image key vectors.
S207: Expand the pathological image embedding output by the pathological image encoder along the sequence dimension and concatenate it with the gene embeddings output by the gene name embedding layer of the gene expression encoder to obtain a fused feature; predict gene expression values from the fused feature; and calculate a reconstruction loss from the predicted gene expression values and the gene expression values in the training sample.
In this step, an auxiliary reconstruction task is introduced alongside the contrastive loss. It reconstructs the original gene expression values from the multi-scale pathological image embeddings, which preserves fine-grained within-modality information during alignment. Specifically, the multi-scale pathological image embeddings are first average-pooled into a global image embedding. The gene name embedding layer of the frozen gene expression encoder then encodes the gene identifier sequence into gene embedding vectors. The global image embedding is repeated along the sequence dimension to match the length of the gene embedding sequence, and the two are concatenated to obtain a fused feature. This fused feature is passed through a conditional reconstruction head composed of multilayer perceptrons, which outputs a predicted expression value for each gene. The reconstruction loss is a masked mean squared error (MSE): the squared error between predicted and true values is averaged only over genes at non-zero, non-padded positions in the training sample. This loss forces the model to learn the mapping from image to gene expression, providing a stronger supervisory signal for contrastive learning.
[0051] In a specific implementation, as shown in Figure 6, the conditional reconstruction head processes data as follows: the pathological image embeddings at the three scales are average-pooled into a 1536-dimensional global image embedding; the gene name embedding layer of the frozen gene expression encoder encodes the gene identifier sequence into a sequence of 512-dimensional gene embeddings; the global image embedding is repeated along the sequence dimension and concatenated with each gene embedding to form a 2048-dimensional fused input; the fused input is processed by a two-layer fusion multilayer perceptron (2048-dimensional input to 256-dimensional output, then 256-dimensional input to 256-dimensional output, each layer with a ReLU activation and LayerNorm normalization); and an output linear layer (256-dimensional input to 1-dimensional output) predicts the log1p expression value of each gene.
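The input construction of the conditional reconstruction head in [0051] (pooling over scales, repeating along the gene sequence, concatenating) is sketched below with plain lists; the two-layer MLP and the output layer that follow are omitted. The stand-in embeddings are random, and the function names are illustrative.

```python
import random

def average_pool(scale_embs):
    # Mean over the scale axis -> one global image embedding.
    n = len(scale_embs)
    return [sum(e[j] for e in scale_embs) / n for j in range(len(scale_embs[0]))]

def fused_inputs(scale_image_embs, gene_embs):
    """Repeat the pooled image embedding for every gene token and concatenate.

    Returns one row per gene of length len(image_emb) + len(gene_emb),
    i.e. 1536 + 512 = 2048 with the dimensions given in the text.
    """
    global_img = average_pool(scale_image_embs)
    return [global_img + gene_emb for gene_emb in gene_embs]  # list concatenation

rng = random.Random(0)
image_embs = [[rng.gauss(0.0, 1.0) for _ in range(1536)] for _ in range(3)]  # 3 scales
gene_embs = [[rng.gauss(0.0, 1.0) for _ in range(512)] for _ in range(5)]    # 5 gene tokens
rows = fused_inputs(image_embs, gene_embs)
```

Every row shares the same image part, so the head can only tell genes apart through their gene embeddings, which is what makes the prediction conditional on the gene identity.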
[0052] The reconstruction loss is a masked mean squared error computed only over genes at non-zero, non-padded positions: L_recon = MSE(predicted[non_zero_mask], target[non_zero_mask]), where L_recon is the reconstruction loss, MSE is the mean squared error loss function, predicted is the predicted gene expression value, target is the true gene expression value, and non_zero_mask marks the non-zero, non-padded genes. Both predicted and target may be normalized by dividing the predicted and true gene expression values by 50.0.
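The masked reconstruction loss of [0052] is sketched below. Padding is represented by a separate boolean list, and returning 0.0 when no gene is valid is an assumption. Otherwise the sketch follows the text: genes with a zero true value or a padded position are excluded, and both predictions and targets are divided by 50.0 before the squared error is averaged.

```python
def masked_mse(predicted, target, pad_mask, scale=50.0):
    """MSE over genes whose true value is non-zero and whose position is not padding.

    pad_mask[i] is True for padded positions; values are divided by `scale` first.
    """
    pairs = [(p / scale, t / scale) for p, t, pad in zip(predicted, target, pad_mask)
             if t != 0 and not pad]
    if not pairs:
        return 0.0  # assumption: no valid genes contributes zero loss
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)

# Gene 1 has a zero target and gene 3 is padding, so only genes 0 and 2 count.
loss = masked_mse([50.0, 100.0, 7.0, 9.0], [100.0, 0.0, 50.0, 3.0], [False, False, False, True])
```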
[0053] In another embodiment, the reconstruction loss is assigned different weights for different spatial scales, for example, the single-point scale weight is 0.00023, the first neighborhood scale weight is 0.00012, and the second neighborhood scale weight is 0.000007, to reflect the difference in the contribution of information at different scales to the reconstruction.
[0054] S208: Calculate the total training loss based on the contrastive loss and the reconstruction loss; update the parameters of the low-rank adaptation modules in the query branch through backpropagation based on the total training loss, update the parameters of the low-rank adaptation modules in the key branch through exponential moving average, and write the key vectors corresponding to the current training samples into the memory queue.
In this step, the total training loss is the weighted sum of the contrastive loss and the reconstruction loss: L = λ_nce·L_nce + λ_recon·L_recon, where L is the total training loss, L_nce is the contrastive loss, L_recon is the reconstruction loss, λ_nce is the contrastive loss weight, and λ_recon is the reconstruction loss weight. For example, the contrastive loss weight is 1.0 and the reconstruction loss weight is 0.3.
[0055] By jointly optimizing these two losses, the model aligns the representations of the different modalities in a unified space while keeping enough information in the pathological image encoder output to predict gene expression, which improves the information richness of the aligned representation. The parameter update method is the same as in the previous embodiment: the parameters of the low-rank adaptation modules in the query branch are updated through backpropagation, the parameters of the key branch are updated through exponential moving average, and the key vectors are written to the memory queue through the coordinated gradient accumulation and queue update mechanism. This joint training finally yields a pathological image encoder that encodes pathological images into representations aligned with the gene expression semantic space.
[0056] S209: Input the pathological image to be analyzed into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, where the pathological image representation is applied to downstream tasks including any one of cross-modal retrieval, tissue classification, and gene expression prediction.
[0057] Therefore, this embodiment injects low-rank adaptation modules into the pre-trained pathological image encoder and gene expression encoder and freezes the base model parameters, training only a small number of adaptation parameters; this reduces GPU memory usage and training time while avoiding catastrophic forgetting. It uses the memory queue of the momentum contrastive learning framework to expand the number of negative samples to a large scale without increasing GPU memory, ensuring a sufficient contrastive signal and improving cross-modal alignment quality. It uses dual-pyramid multi-scale fusion, letting the pathological image side and the gene expression side learn their fusion weights independently, to make full use of the multi-scale spatial structure of spatial transcriptomics data. It jointly trains the contrastive loss and the conditional reconstruction loss to preserve fine-grained within-modality information during cross-modal alignment, improving the information richness of the aligned representation and its applicability to downstream tasks. Finally, it uses a dual-view asymmetric augmentation strategy based on the hematoxylin-eosin-DAB color space to simulate staining differences in pathological images, improving the robustness of contrastive learning to staining variation.
[0058] Based on the above embodiments, a two-stage training strategy is adopted as a feasible implementation method.
[0059] In practice, the two-stage training strategy is as follows. Stage 1 is reconstruction pre-training: only the reconstruction branch is trained and the contrastive loss is not calculated, so that the low-rank adaptation modules first learn the basic mapping from image to gene, providing a better initialization for subsequent contrastive learning; during this stage, the key branch is kept synchronized with the query branch by exponential moving average. Stage 2 is joint contrastive and reconstruction training: the contrastive learning branch and the reconstruction branch are trained simultaneously, and the total loss is the weighted sum of the contrastive loss and the reconstruction loss. The epoch at which training switches from stage 1 to stage 2 is set by a configuration parameter, typically between 0 and 20; the default value of 0 means that joint training begins from the first training epoch (epoch 0), so stage 1 is skipped.
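The stage switch of [0059] combined with the total loss of [0054] can be sketched as a single function. The default weights (1.0 and 0.3) and switch epoch (0) come from the text; using the unweighted reconstruction loss in stage 1 is an assumption, since the text does not say whether λ_recon applies there.

```python
def training_loss(epoch, l_nce, l_recon, switch_epoch=0, w_nce=1.0, w_recon=0.3):
    """Stage 1 (epoch < switch_epoch): reconstruction loss only.
    Stage 2: L = w_nce * L_nce + w_recon * L_recon."""
    if epoch < switch_epoch:
        return l_recon  # the contrastive loss is not computed in stage 1
    return w_nce * l_nce + w_recon * l_recon
```

With the default switch epoch of 0, every epoch falls in stage 2.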
[0060] In one embodiment, the key parameters of the above embodiments can be adjusted within the following ranges: the low-rank adaptation rank from 4 to 64, the scaling factor from 4 to 64, and the dropout rate from 0.0 to 0.3; the projection space dimension from 64 to 512; the temperature coefficient from 0.01 to 0.5; the momentum coefficient from 0.99 to 0.9999; the memory queue size from 4096 to 65536; the batch size from 8 to 128; the contrastive loss weight from 0.5 to 2.0 and the reconstruction loss weight from 0.1 to 1.0; the learning rate from 1×10⁻⁵ to 1×10⁻³; the minimum learning rate from 1×10⁻⁸ to 1×10⁻⁶; the maximum number of training epochs from 50 to 500; and the gradient clipping threshold from 1.0 to 20.0. In addition, the spatial scales can be adjusted through the number of neighboring loci, and the HED perturbation intensity ranges from 0.01 to 0.1.
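The ranges of [0060] can be captured as a small configuration check, sketched below. The key names are hypothetical, and DEFAULTS collects only example values stated elsewhere in the text (learning rates, epochs, clipping, momentum, queue size, temperature, projection size, and loss weights). The rank, scaling factor, dropout, batch size, and perturbation intensity have ranges but no stated defaults, so they are checked only when supplied.

```python
# Documented ranges (inclusive); key names are illustrative.
RANGES = {
    "lora_rank": (4, 64), "lora_alpha": (4, 64), "lora_dropout": (0.0, 0.3),
    "projection_dim": (64, 512), "temperature": (0.01, 0.5),
    "momentum": (0.99, 0.9999), "queue_size": (4096, 65536),
    "batch_size": (8, 128), "w_nce": (0.5, 2.0), "w_recon": (0.1, 1.0),
    "lr": (1e-5, 1e-3), "min_lr": (1e-8, 1e-6), "max_epochs": (50, 500),
    "grad_clip": (1.0, 20.0), "hed_sigma": (0.01, 0.1),
}

# Example values stated in the embodiments.
DEFAULTS = {
    "temperature": 0.07, "momentum": 0.999, "queue_size": 32768,
    "projection_dim": 128, "lr": 1e-4, "min_lr": 1e-8, "max_epochs": 100,
    "grad_clip": 10.0, "w_nce": 1.0, "w_recon": 0.3,
}

def out_of_range(config):
    """Return the names of supplied settings that fall outside the documented ranges."""
    return [name for name, value in config.items()
            if name in RANGES and not (RANGES[name][0] <= value <= RANGES[name][1])]
```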
[0061] The following describes a training apparatus for a pre-trained model provided in an embodiment of this application. The training apparatus described below and the training method described above may be read with reference to each other.
[0062] Figure 7 is a structural diagram of a training device for a pre-trained model according to an exemplary embodiment. As shown in Figure 7, the device includes:
The acquisition unit 100 is configured to acquire a training dataset, wherein the training dataset includes multiple training samples, and the training samples include pathological images and corresponding gene expression data.
The construction unit 200 is configured to construct a pre-trained model, wherein the pre-trained model includes a query branch and a key branch. The query branch includes a first query sub-branch and a second query sub-branch: the first query sub-branch includes a pathological image encoder, an image projection head and an image prediction head connected in sequence, and the second query sub-branch includes a gene expression encoder, a gene projection head and a gene prediction head connected in sequence. The key branch includes a first key sub-branch and a second key sub-branch: the first key sub-branch includes a pathological image encoder and an image projection head connected in sequence, and the second key sub-branch includes a gene expression encoder and a gene projection head connected in sequence. The pathological image encoder and the gene expression encoder include trainable parameter structures.
The input unit 300 is configured to input the training samples into the pre-trained model and generate the query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch.
The first calculation unit 400 is configured to calculate an alignment loss based on the similarity between the query vectors and a key vector set, wherein the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in a memory queue. The alignment loss is calculated from a first alignment loss and a second alignment loss: the first alignment loss is calculated based on the similarity between the pathological image query vector and the gene expression key vectors, and the second alignment loss is calculated based on the similarity between the gene expression query vector and the pathological image key vectors.
The update unit 500 is configured to update the trainable parameters in the query branch through backpropagation based on the alignment loss, update the trainable parameters in the key branch through an exponential moving average, and write the key vectors corresponding to the current training samples into the memory queue.
The inference unit 600 is configured to input a pathological image to be inferred into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, wherein the pathological image representation is applied to downstream tasks including any one of cross-modal retrieval, tissue classification, and gene expression prediction.
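The bidirectional alignment loss computed by the first calculation unit 400 can be illustrated with a minimal sketch. It assumes an InfoNCE-style contrastive loss (the loss form usual in momentum contrastive learning; the text above does not name the exact formula), cosine similarity, a temperature hyperparameter, and an equal-weight average of the first and second alignment losses. All function names are hypothetical. Query `i` is treated as positive with key `i` of the other modality; the other keys in the current batch and all historical keys in that modality's memory queue serve as negatives.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit length so that dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / norm for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def info_nce(queries: List[Vector], batch_keys: List[Vector],
             queue_keys: List[Vector], temperature: float = 0.07) -> float:
    """Mean InfoNCE loss. Query i's positive is batch_keys[i]; every other key
    in the current batch and every historical key in the queue is a negative."""
    key_set = [l2_normalize(k) for k in batch_keys + queue_keys]
    total = 0.0
    for i, query in enumerate(queries):
        q = l2_normalize(query)
        logits = [dot(q, k) / temperature for k in key_set]
        peak = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_sum_exp - logits[i]  # equals -log(softmax probability of the positive)
    return total / len(queries)


def alignment_loss(img_q, gene_q, img_k, gene_k, img_queue, gene_queue,
                   temperature: float = 0.07) -> float:
    """Bidirectional cross-modal loss; averaging the two directions is an assumption."""
    # First alignment loss: pathological image queries against gene expression keys.
    first = info_nce(img_q, gene_k, gene_queue, temperature)
    # Second alignment loss: gene expression queries against pathological image keys.
    second = info_nce(gene_q, img_k, img_queue, temperature)
    return 0.5 * (first + second)
```

Keeping a separate queue per modality reflects the pairing described above: image queries are contrasted only against gene expression keys, and gene expression queries only against pathological image keys.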
[0063] By freezing the basic parameters of the pre-trained encoders and updating only the injected trainable parameter structures, the training apparatus provided in this application substantially reduces the total number of parameters to be trained. This parameter-efficient fine-tuning approach lowers GPU memory usage during training and shortens training time, so that the model can be trained at lower cost. At the same time, because the base model parameters are frozen, the general representation capabilities learned on single-modal data are preserved, which avoids the catastrophic forgetting caused by full-parameter fine-tuning. Furthermore, this application employs a momentum contrastive learning framework that maintains a memory queue of key vectors from historical training samples, so that the number of negative samples available when calculating the alignment loss far exceeds the current batch size. This enriches the negative samples without increasing the GPU memory burden, provides a stronger alignment signal, and thus improves the quality and stability of cross-modal alignment. By optimizing the alignment loss, pathological images and gene expression data are mapped to a unified semantic space, achieving effective alignment between the two modalities.
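Once both modalities share one semantic space, a downstream task such as cross-modal retrieval reduces to a nearest-neighbour search. The following minimal sketch assumes that image and gene expression representations have already been produced by the trained encoders and projection heads and are compared by cosine similarity; `cosine` and `retrieve` are hypothetical helper names, not part of the described apparatus.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def retrieve(query: Vector, gallery: List[Vector], top_k: int = 1) -> List[int]:
    """Indices of the top_k gallery embeddings most similar to the query.

    In the cross-modal setting the query would be an image representation and the
    gallery would hold gene expression representations (or vice versa), both
    assumed to live in the shared semantic space produced by training.
    """
    ranked = sorted(range(len(gallery)),
                    key=lambda idx: cosine(query, gallery[idx]),
                    reverse=True)
    return ranked[:top_k]
```

For example, `retrieve([1.0, 0.0], [[0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]], top_k=3)` returns `[1, 0, 2]`. A production system would likely use a vector index rather than a linear scan, but the ranking logic is the same.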
[0064] Based on the above embodiments, as a preferred embodiment, the device further includes an extraction unit configured to extract data at multiple spatial scales from the training samples, wherein the multiple spatial scales include a single-point scale, a first neighborhood scale, and a second neighborhood scale, the second neighborhood scale being larger than the first neighborhood scale. The pathological image encoder encodes the pathological images at each spatial scale to obtain multi-scale pathological image embeddings, and the gene expression encoder encodes the gene expression data at each spatial scale to obtain multi-scale gene expression embeddings. The first query sub-branch further includes a first multi-scale fusion module, connected between the pathological image encoder and the image projection head, which fuses the multi-scale pathological image embeddings to obtain a fused pathological image embedding; the second query sub-branch further includes a second multi-scale fusion module, connected between the gene expression encoder and the gene projection head, which fuses the multi-scale gene expression embeddings to obtain a fused gene expression embedding. The first key sub-branch further includes a third multi-scale fusion module, connected between the pathological image encoder and the image projection head, which fuses the multi-scale pathological image embeddings to obtain a fused pathological image embedding; the second key sub-branch further includes a fourth multi-scale fusion module, connected between the gene expression encoder and the gene projection head, which fuses the multi-scale gene expression embeddings to obtain a fused gene expression embedding.
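The multi-scale extraction and fusion steps can be sketched as follows. The text does not specify neighborhood sizes or the fusion operator, so this sketch makes three loudly labeled assumptions: spots sit on a square grid indexed by `(row, col)`; the first and second neighborhoods are all spots within Chebyshev radius 1 and 2, averaged; and fusion is a softmax-weighted sum with one learnable logit per scale. The function names are hypothetical, and in the test the raw multi-scale vectors stand in for encoder outputs.

```python
import math
from typing import Dict, List, Sequence, Tuple

Spot = Tuple[int, int]
Vector = List[float]


def neighborhood_mean(grid: Dict[Spot, Vector], center: Spot, radius: int) -> Vector:
    """Average the vectors of all spots within Chebyshev distance `radius` of center."""
    r0, c0 = center
    members = [v for (r, c), v in grid.items()
               if max(abs(r - r0), abs(c - c0)) <= radius]
    dim = len(members[0])
    return [sum(v[d] for v in members) / len(members) for d in range(dim)]


def multi_scale_inputs(grid: Dict[Spot, Vector], center: Spot,
                       radii: Sequence[int] = (0, 1, 2)) -> List[Vector]:
    """Single-point (radius 0), first-neighborhood and second-neighborhood views."""
    return [neighborhood_mean(grid, center, r) for r in radii]


def fuse(embeddings: List[Vector], scale_logits: List[float]) -> Vector:
    """Softmax-weighted sum of per-scale embeddings (learnable logits assumed)."""
    peak = max(scale_logits)
    weights = [math.exp(s - peak) for s in scale_logits]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(embeddings[0])
    return [sum(w * e[d] for w, e in zip(weights, embeddings)) for d in range(dim)]
```

Spots near the tissue edge simply average over fewer neighbours here; padding or attention-based fusion would be equally plausible choices.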
[0065] Based on the above embodiments, as a preferred embodiment, the device further includes an augmentation processing unit configured to augment the pathological image using a first augmentation strategy to obtain a query pathological image, and to augment the pathological image using a second augmentation strategy to obtain a key pathological image, wherein the augmentation intensity of the first augmentation strategy is higher than that of the second augmentation strategy. Accordingly, the input unit 300 is specifically configured to: input the query pathological image into the first query sub-branch to generate the pathological image query vector, and input the gene expression data into the second query sub-branch to generate the gene expression query vector; and input the key pathological image into the first key sub-branch to generate the pathological image key vector, and input the gene expression data into the second key sub-branch to generate the gene expression key vector.
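The asymmetric augmentation can be illustrated on a toy grayscale image. The concrete operations are assumptions, since the text only requires the first (query) strategy to be stronger than the second (key) strategy: here the key view gets only a random horizontal flip, while the query view additionally gets a brightness shift and Gaussian noise. The function names and parameter values are hypothetical.

```python
import random
from typing import List, Tuple

Image = List[List[float]]  # toy grayscale image, intensities in [0, 1]


def horizontal_flip(img: Image) -> Image:
    return [row[::-1] for row in img]


def weak_augment(img: Image, rng: random.Random) -> Image:
    """Key view (second strategy): only a random horizontal flip."""
    return horizontal_flip(img) if rng.random() < 0.5 else [row[:] for row in img]


def strong_augment(img: Image, rng: random.Random,
                   jitter: float = 0.2, noise_std: float = 0.05) -> Image:
    """Query view (first strategy): random flip, brightness shift and Gaussian noise."""
    out = weak_augment(img, rng)
    shift = rng.uniform(-jitter, jitter)
    # Clip every perturbed pixel back into the valid [0, 1] range.
    return [[min(1.0, max(0.0, p + shift + rng.gauss(0.0, noise_std))) for p in row]
            for row in out]


def make_views(img: Image, seed: int = 0) -> Tuple[Image, Image]:
    """Return (query_image, key_image) generated from one pathological image."""
    rng = random.Random(seed)
    return strong_augment(img, rng), weak_augment(img, rng)
```

The query view feeds the first query sub-branch and the key view feeds the first key sub-branch. Keeping the key view mildly perturbed tends to give the slowly updated key encoder stable targets, while the stronger query augmentation acts as regularization.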
[0066] Based on the above embodiments, as a preferred embodiment, the gene expression data includes gene identifier sequences and gene expression values, and the device further includes a second calculation unit configured to: expand (repeat) the pathological image embedding output by the pathological image encoder along the sequence dimension and concatenate it with the gene embeddings output by the gene name embedding layer of the gene expression encoder to obtain fusion features; predict gene expression values based on the fusion features; calculate a reconstruction loss based on the predicted gene expression values and the gene expression values in the training samples; and calculate a total training loss based on the alignment loss and the reconstruction loss. Accordingly, the update unit 500 is specifically configured to update the trainable parameters in the query branch through backpropagation based on the total training loss.
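The reconstruction branch can be sketched with plain lists. The sketch assumes a linear regression head, mean squared error as the reconstruction loss, and a weighted sum for the total loss with a hypothetical hyperparameter `recon_weight`; the text states only that the total loss is calculated from the alignment and reconstruction losses. "Expanding along the sequence dimension" is shown as repeating the single image embedding once per gene token before feature-wise concatenation.

```python
from typing import List

Vector = List[float]


def fusion_features(img_emb: Vector, gene_embs: List[Vector]) -> List[Vector]:
    """Repeat the image embedding once per gene token and concatenate feature-wise."""
    return [img_emb + g for g in gene_embs]  # list '+' concatenates the two feature vectors


def predict_expression(features: List[Vector], weights: Vector, bias: float) -> List[float]:
    """Hypothetical linear head mapping each fused token to one expression value."""
    return [sum(w * x for w, x in zip(weights, f)) + bias for f in features]


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error between predicted and observed expression values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def total_loss(align: float, recon: float, recon_weight: float = 1.0) -> float:
    """Weighted sum of the two losses; recon_weight is an assumed hyperparameter."""
    return align + recon_weight * recon
```

Because the image embedding is shared across every gene token, gradients from the reconstruction loss push the image encoder to carry information that is predictive of per-gene expression.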
[0067] Based on the above embodiments, as a preferred implementation, the update unit 500 is specifically configured to: in each forward propagation, store the key vectors produced by that forward propagation in a buffer; perform a parameter update after every preset number of forward propagations, in which the trainable parameters in the query branch are updated through backpropagation based on the gradients accumulated during the preset number of forward propagations; update the trainable parameters in the key branch by an exponential moving average based on the updated trainable parameters in the query branch, wherein the update amount of the trainable parameters in the key branch is determined by the weighted difference between the current values of the trainable parameters in the query branch and those in the key branch; and write the key vectors of the preset number of forward propagations stored in the buffer into the memory queue, then clear the buffer.
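The update schedule above (gradient accumulation, exponential moving average for the key branch, and buffered enqueueing) can be traced with scalar toy parameters. This is a minimal sketch under stated assumptions: gradients are supplied by the caller in place of real backpropagation, the query branch uses plain SGD on the averaged accumulated gradient, and the memory queue is a fixed-size FIFO (`collections.deque` with `maxlen`). The class and method names are hypothetical. The EMA line implements the weighted-difference form θ_k ← θ_k + (1 − m)(θ_q − θ_k).

```python
from collections import deque
from typing import Deque, Dict, List

Params = Dict[str, float]


def ema_update(key_params: Params, query_params: Params, momentum: float) -> None:
    """theta_k <- theta_k + (1 - m) * (theta_q - theta_k), i.e. m*theta_k + (1-m)*theta_q."""
    for name, q_val in query_params.items():
        k_val = key_params[name]
        key_params[name] = k_val + (1.0 - momentum) * (q_val - k_val)


class AccumulatingMomentumTrainer:
    """Toy scalar-parameter trainer: gradient accumulation, EMA key branch, key buffer."""

    def __init__(self, params: Params, accum_steps: int, lr: float,
                 momentum: float, queue_size: int) -> None:
        self.query_params = dict(params)
        self.key_params = dict(params)  # key branch starts as a copy of the query branch
        self.accum_steps = accum_steps
        self.lr = lr
        self.momentum = momentum
        self.grad_sum = {name: 0.0 for name in params}
        self.buffer: List[object] = []
        self.queue: Deque[object] = deque(maxlen=queue_size)  # FIFO: oldest keys drop out
        self.passes = 0

    def step(self, grads: Params, keys: List[object]) -> None:
        """Record one forward/backward pass; apply an update every accum_steps passes."""
        for name, g in grads.items():
            self.grad_sum[name] += g
        self.buffer.extend(keys)  # keys wait in the buffer until the next update
        self.passes += 1
        if self.passes % self.accum_steps == 0:
            self._update()

    def _update(self) -> None:
        for name in self.query_params:
            # Plain SGD on the averaged accumulated gradient (optimizer choice assumed).
            self.query_params[name] -= self.lr * self.grad_sum[name] / self.accum_steps
            self.grad_sum[name] = 0.0
        ema_update(self.key_params, self.query_params, self.momentum)
        self.queue.extend(self.buffer)  # write the buffered keys into the memory queue
        self.buffer.clear()
```

Deferring the enqueue until the update means every key in the queue was produced by a key encoder state that existed before the latest EMA step, which keeps the queue consistent with the accumulated gradient window.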
[0068] Based on the above embodiments, as a preferred implementation, the trainable parameter structure is a low-rank adaptation (LoRA) module. In the pathological image encoder, the target layers into which the low-rank adaptation module is injected include the joint query-key-value projection layer and the output projection layer of the attention layers of the encoder's vision Transformer. In the gene expression encoder, the target layers include the query projection layer, key projection layer, value projection layer, and output projection layer of the encoder's Transformer, as well as the linear layers of the feed-forward network.
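A low-rank adaptation layer keeps the base weight `W` frozen and learns only a rank-`r` update, computing `y = W x + (alpha / r) * B (A x)`. The sketch below follows the common LoRA convention of initializing `B` to zero so that training starts from the frozen model. The target-layer suffixes are assumptions: `attn.qkv` and `attn.proj` follow timm-style ViT naming, while the gene encoder names are hypothetical placeholders, since the actual module names are not given.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]
Vector = List[float]

# Assumed layer-name suffixes: timm-style ViT blocks name the fused projection
# "attn.qkv" and the output projection "attn.proj"; the gene encoder names are
# hypothetical placeholders.
IMAGE_TARGETS: Tuple[str, ...] = ("attn.qkv", "attn.proj")
GENE_TARGETS: Tuple[str, ...] = ("q_proj", "k_proj", "v_proj", "out_proj",
                                 "ffn.linear1", "ffn.linear2")


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def select_targets(layer_names: Sequence[str], suffixes: Tuple[str, ...]) -> List[str]:
    """Names of the layers that would receive a LoRA module."""
    return [name for name in layer_names if name.endswith(suffixes)]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x), with W frozen and only A, B trainable."""

    def __init__(self, weight: Matrix, a_init: Matrix, alpha: float) -> None:
        self.weight = weight  # frozen base weight, shape d_out x d_in
        self.rank = len(a_init)
        self.A = a_init  # trainable, shape r x d_in
        # B starts at zero, so the adapted layer initially equals the frozen layer.
        self.B = [[0.0] * self.rank for _ in range(len(weight))]  # d_out x r
        self.scale = alpha / self.rank

    def __call__(self, x: Vector) -> Vector:
        base = matvec(self.weight, x)
        delta = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * d for b, d in zip(base, delta)]

    def trainable_count(self) -> int:
        return self.rank * len(self.A[0]) + len(self.B) * self.rank
```

The parameter savings follow from the shapes. For a hypothetical 1024 × 1024 projection with rank 8, LoRA trains 8 × 1024 + 1024 × 8 = 16,384 parameters instead of 1,048,576, or 1/64 of the full weight.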
[0069] Regarding the apparatus in the above embodiments, the specific manner in which each module performs its operation has been described in detail in the embodiments related to the method, and will not be elaborated upon here.
[0070] Based on the hardware implementation of the above program modules, and to implement the method of the embodiments of this application, the embodiments of this application also provide an electronic device. Figure 8 is a structural diagram of an electronic device according to an exemplary embodiment. As shown in Figure 8, the electronic device includes: a communication interface 1, which exchanges information with other devices such as network devices; a processor 2, which is connected to the communication interface 1 to exchange information with other devices and which, when running a computer program, executes the training method of the pre-trained model provided by one or more of the above technical solutions; and a memory 3, in which the computer program is stored.
[0071] Of course, in practical applications, the components of the electronic device are coupled together through a bus system 4, which implements connection and communication among these components. In addition to a data bus, bus system 4 includes a power bus, a control bus, and a status signal bus. For clarity of illustration, however, all buses are labeled as bus system 4 in Figure 8.
[0072] The memory 3 in this embodiment is used to store various types of data to support the operation of the electronic device. Examples of such data include any computer program used to operate on the electronic device.
[0073] It is understood that memory 3 can be volatile memory or non-volatile memory, or both. Non-volatile memory can be read-only memory (ROM), programmable read-only memory (PROM), erasable programmable read-only memory (EPROM), electrically erasable programmable read-only memory (EEPROM), ferromagnetic random access memory (FRAM), flash memory, magnetic surface memory, optical disc, or compact disc read-only memory (CD-ROM); magnetic surface memory can be disk storage or magnetic tape storage. Volatile memory can be random access memory (RAM), which is used as an external cache. By way of example, but not limitation, many forms of RAM are available, such as Static Random Access Memory (SRAM), Synchronous Static Random Access Memory (SSRAM), Dynamic Random Access Memory (DRAM), Synchronous Dynamic Random Access Memory (SDRAM), Double Data Rate Synchronous Dynamic Random Access Memory (DDRSDRAM), Enhanced Synchronous Dynamic Random Access Memory (ESDRAM), SyncLink Dynamic Random Access Memory (SLDRAM), and Direct Rambus Random Access Memory (DRRAM). The memory 3 described in the embodiments of this application is intended to include, but is not limited to, these and any other suitable types of memory.
[0074] The methods disclosed in the embodiments of this application can be applied to, or implemented by, processor 2. Processor 2 may be an integrated circuit chip with signal processing capabilities. During implementation, each step of the above method can be completed by integrated hardware logic circuits in processor 2 or by instructions in the form of software. Processor 2 may be a general-purpose processor, a digital signal processor (DSP), another programmable logic device, a discrete gate or transistor logic device, a discrete hardware component, or the like. Processor 2 can implement or execute the methods, steps and logic block diagrams disclosed in the embodiments of this application. The general-purpose processor may be a microprocessor or any conventional processor. The steps of the methods disclosed in the embodiments of this application can be executed directly by a hardware decoding processor, or by a combination of hardware and software modules in the decoding processor. The software modules may be located in a storage medium, which is located in memory 3. Processor 2 reads the program in memory 3 and completes the steps of the aforementioned method in combination with its hardware.
[0075] When processor 2 executes the program, it implements the corresponding processes in the various methods of the embodiments of this application. For the sake of brevity, these will not be described in detail here.
[0076] In an exemplary embodiment, this application also provides a storage medium, namely a computer storage medium, specifically a computer-readable storage medium, such as a memory 3 that stores a computer program, which can be executed by a processor 2 to complete the steps described in the aforementioned method. The computer-readable storage medium may be a memory such as FRAM, ROM, PROM, EPROM, EEPROM, Flash Memory, magnetic surface memory, optical disc, or CD-ROM.
[0077] Those skilled in the art will understand that all or part of the steps of the above method embodiments can be implemented by hardware instructed by a program. The program can be stored in a computer-readable storage medium and, when executed, performs the steps of the above method embodiments. The storage medium includes various media that can store program code, such as removable storage devices, ROM, RAM, magnetic disks, or optical disks.
[0078] Alternatively, if the integrated units described above are implemented as software functional modules and sold or used as independent products, they can also be stored in a computer-readable storage medium. Based on this understanding, the technical solutions of the embodiments of this application, or the parts that contribute to the prior art, can be embodied in the form of a software product. This computer software product is stored in a storage medium and includes several instructions to cause an electronic device (which may be a personal computer, server, network device, etc.) to execute all or part of the methods described in the various embodiments of this application. The storage medium includes various media capable of storing program code, such as removable storage devices, ROM, RAM, magnetic disks, or optical disks.
[0079] The above description is merely a specific embodiment of this application, but the scope of protection of this application is not limited thereto. Any changes or substitutions that can be easily conceived by those skilled in the art within the scope of the technology disclosed in this application should be included within the scope of protection of this application.
Claims
1. A training method for a pre-trained model, characterized in that it comprises: obtaining a training dataset, wherein the training dataset includes multiple training samples, and the training samples include pathological images and corresponding gene expression data; constructing a pre-trained model, wherein the pre-trained model includes a query branch and a key branch, the query branch includes a first query sub-branch and a second query sub-branch, the first query sub-branch includes a pathological image encoder, an image projection head and an image prediction head connected in sequence, the second query sub-branch includes a gene expression encoder, a gene projection head and a gene prediction head connected in sequence, the key branch includes a first key sub-branch and a second key sub-branch, the first key sub-branch includes a pathological image encoder and an image projection head connected in sequence, the second key sub-branch includes a gene expression encoder and a gene projection head connected in sequence, and the pathological image encoder and the gene expression encoder include trainable parameter structures; inputting the training samples into the pre-trained model, and generating query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch; calculating an alignment loss based on the similarity between the query vectors and a key vector set, wherein the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in a memory queue, the alignment loss is calculated from a first alignment loss and a second alignment loss, the first alignment loss is calculated based on the similarity between the pathological image query vector and the gene expression key vectors, and the second alignment loss is calculated based on the similarity between the gene expression query vector and the pathological image key vectors; based on the alignment loss, updating the trainable parameters in the query branch through backpropagation, updating the trainable parameters in the key branch through an exponential moving average, and writing the key vectors corresponding to the current training samples into the memory queue; and inputting a pathological image to be inferred into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, wherein the pathological image representation is applied to downstream tasks including any one of cross-modal retrieval, tissue classification, and gene expression prediction.
2. The training method for the pre-trained model according to claim 1, characterized in that, after obtaining the training dataset, the method further comprises: extracting data at multiple spatial scales from the training samples, wherein the multiple spatial scales include a single-point scale, a first neighborhood scale, and a second neighborhood scale, the second neighborhood scale being larger than the first neighborhood scale; encoding, by the pathological image encoder, the pathological images at each spatial scale to obtain multi-scale pathological image embeddings; and encoding, by the gene expression encoder, the gene expression data at each spatial scale to obtain multi-scale gene expression embeddings; wherein the first query sub-branch further includes a first multi-scale fusion module, connected between the pathological image encoder and the image projection head, configured to fuse the multi-scale pathological image embeddings to obtain a fused pathological image embedding; the second query sub-branch further includes a second multi-scale fusion module, connected between the gene expression encoder and the gene projection head, configured to fuse the multi-scale gene expression embeddings to obtain a fused gene expression embedding; the first key sub-branch further includes a third multi-scale fusion module, connected between the pathological image encoder and the image projection head, configured to fuse the multi-scale pathological image embeddings to obtain a fused pathological image embedding; and the second key sub-branch further includes a fourth multi-scale fusion module, connected between the gene expression encoder and the gene projection head, configured to fuse the multi-scale gene expression embeddings to obtain a fused gene expression embedding.
3. The training method for the pre-trained model according to claim 1, characterized in that, before inputting the training samples into the pre-trained model, the method further comprises: augmenting the pathological image using a first augmentation strategy to obtain a query pathological image, and augmenting the pathological image using a second augmentation strategy to obtain a key pathological image, wherein the augmentation intensity of the first augmentation strategy is higher than that of the second augmentation strategy; accordingly, inputting the training samples into the pre-trained model and generating the query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, comprises: inputting the query pathological image into the first query sub-branch to generate the pathological image query vector, and inputting the gene expression data into the second query sub-branch to generate the gene expression query vector; and inputting the key pathological image into the first key sub-branch to generate the pathological image key vector, and inputting the gene expression data into the second key sub-branch to generate the gene expression key vector.
4. The training method for the pre-trained model according to claim 1, characterized in that the gene expression data includes gene identifier sequences and gene expression values, and, after inputting the training samples into the pre-trained model, the method further comprises: expanding (repeating) the pathological image embedding output by the pathological image encoder along the sequence dimension and concatenating it with the gene embeddings output by the gene name embedding layer of the gene expression encoder to obtain fusion features; predicting gene expression values based on the fusion features, calculating a reconstruction loss based on the predicted gene expression values and the gene expression values in the training samples, and calculating a total training loss based on the alignment loss and the reconstruction loss; accordingly, updating the trainable parameters in the query branch through backpropagation based on the alignment loss comprises: updating the trainable parameters in the query branch through backpropagation based on the total training loss.
5. The training method for the pre-trained model according to claim 1, characterized in that updating the trainable parameters in the query branch through backpropagation based on the alignment loss, updating the trainable parameters in the key branch through an exponential moving average, and writing the key vectors corresponding to the current training samples into the memory queue comprises: in each forward propagation, storing the key vectors of the current forward propagation in a buffer; performing a parameter update after every preset number of forward propagations, in which the trainable parameters in the query branch are updated through backpropagation based on the gradients accumulated during the preset number of forward propagations; updating the trainable parameters in the key branch by an exponential moving average based on the updated trainable parameters in the query branch, wherein the update amount of the trainable parameters in the key branch is determined by the weighted difference between the current values of the trainable parameters in the query branch and those in the key branch; and writing the key vectors of the preset number of forward propagations stored in the buffer into the memory queue, and then clearing the buffer.
6. The training method for the pre-trained model according to claim 1, characterized in that the trainable parameter structure is a low-rank adaptation module; the target layers of the pathological image encoder into which the low-rank adaptation module is injected include the joint query-key-value projection layer and the output projection layer of the attention layers of the vision Transformer of the pathological image encoder; and the target layers of the gene expression encoder into which the low-rank adaptation module is injected include the query projection layer, key projection layer, value projection layer, and output projection layer of the Transformer of the gene expression encoder, and the linear layers of the feed-forward network.
7. A training device for a pre-trained model, characterized in that it comprises: an acquisition unit configured to acquire a training dataset, wherein the training dataset includes multiple training samples, and the training samples include pathological images and corresponding gene expression data; a construction unit configured to construct a pre-trained model, wherein the pre-trained model includes a query branch and a key branch, the query branch includes a first query sub-branch and a second query sub-branch, the first query sub-branch includes a pathological image encoder, an image projection head and an image prediction head connected in sequence, the second query sub-branch includes a gene expression encoder, a gene projection head and a gene prediction head connected in sequence, the key branch includes a first key sub-branch and a second key sub-branch, the first key sub-branch includes a pathological image encoder and an image projection head connected in sequence, the second key sub-branch includes a gene expression encoder and a gene projection head connected in sequence, and the pathological image encoder and the gene expression encoder include trainable parameter structures; an input unit configured to input the training samples into the pre-trained model and generate the query vectors and key vectors corresponding to the training samples through the query branch and the key branch, respectively, wherein the query vectors include the pathological image query vector output by the first query sub-branch and the gene expression query vector output by the second query sub-branch, and the key vectors include the pathological image key vector output by the first key sub-branch and the gene expression key vector output by the second key sub-branch; a first calculation unit configured to calculate an alignment loss based on the similarity between the query vectors and a key vector set, wherein the key vector set includes the key vectors corresponding to the current training samples and the historical key vectors in a memory queue, the alignment loss is calculated from a first alignment loss and a second alignment loss, the first alignment loss is calculated based on the similarity between the pathological image query vector and the gene expression key vectors, and the second alignment loss is calculated based on the similarity between the gene expression query vector and the pathological image key vectors; an update unit configured to update the trainable parameters in the query branch through backpropagation based on the alignment loss, update the trainable parameters in the key branch through an exponential moving average, and write the key vectors corresponding to the current training samples into the memory queue; and an inference unit configured to input a pathological image to be inferred into the trained pathological image encoder to obtain a pathological image representation aligned with the gene expression semantic space, wherein the pathological image representation is applied to downstream tasks including any one of cross-modal retrieval, tissue classification, and gene expression prediction.
8. An electronic device, characterized in that it comprises: a memory configured to store a computer program; and a processor configured to execute the computer program to implement the steps of the training method for the pre-trained model according to any one of claims 1 to 6.
9. A computer-readable storage medium, characterized in that the computer-readable storage medium stores a computer program that, when executed, implements the steps of the training method for the pre-trained model according to any one of claims 1 to 6.
10. A computer program product, characterized in that it comprises a computer program which, when executed, implements the steps of the training method for the pre-trained model according to any one of claims 1 to 6.