Offline reinforcement learning has shown promise for solving tasks in safety-critical settings, such as clinical decision support. Its application, however, has been limited by clinicians' need for interpretability and interactivity.
To address these challenges, we propose the medical decision transformer (MeDT), a novel and versatile framework based on the goal-conditioned RL paradigm for sepsis treatment recommendation. MeDT is based on the decision transformer architecture and is conditioned on expected treatment outcomes, hindsight patient acuity scores, past dosages, and the patient's current and past medical state at every timestep.
This allows it to consider the complete context of a patient's medical history, enabling more informed decision-making. By conditioning the policy's generation of actions on user-specified goals at every timestep, MeDT enables clinician interaction while avoiding the problem of sparse rewards. Using data from the MIMIC-III dataset, we show that MeDT produces interventions that outperform or are competitive with those of existing methods while enabling a more interpretable, personalized and clinician-directed approach.
Recent research in RL is shifting towards attention-based networks such as transformers, which address many of the challenges faced by RNNs. Chen et al. proposed the Decision Transformer (DT), a transformer-based RL policy learning model that has proven effective for offline RL. DT addresses the problem of sparse or distracting rewards by leveraging self-attention for credit assignment, which incorporates contextual information into the learning process. Furthermore, the transformer's ability to model long sequences enables it to consider the patient's history of states and medications when making predictions.
Building on this idea, we propose an offline RL framework in which treatment dosage recommendation is framed as a sequence modeling problem. The proposed framework, called the Medical Decision Transformer (MeDT), is based on the DT architecture and recommends optimal treatment dosages by autoregressively modeling a patient's state while conditioning on hindsight return information. To provide the policy with more informative and goal-directed input, we also condition MeDT on hindsight patient acuity scores at every time step. This enhances the interpretability of the conditioning, facilitating clinician interaction with the model.
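To make this conditioning concrete, below is a minimal sketch of a MeDT-style policy forward pass in PyTorch. The class name, feature dimensions and the generic causal-transformer backbone are illustrative assumptions rather than the authors' implementation: each timestep contributes a (return-to-go, acuity, state, action) token group, and the action is predicted from the state token.

# Minimal sketch of a MeDT-style policy forward pass (hypothetical dimensions,
# not the authors' exact code).
import torch
import torch.nn as nn

class MeDTPolicy(nn.Module):
    def __init__(self, state_dim, act_dim, acuity_dim, d_model=128,
                 n_layers=3, n_heads=4, max_len=64):
        super().__init__()
        self.embed_rtg = nn.Linear(1, d_model)               # return-to-go token
        self.embed_acuity = nn.Linear(acuity_dim, d_model)   # per-system acuity token
        self.embed_state = nn.Linear(state_dim, d_model)
        self.embed_action = nn.Linear(act_dim, d_model)
        self.embed_time = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=4 * d_model,
                                           batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, n_layers)
        self.action_head = nn.Linear(d_model, act_dim)

    def forward(self, rtg, acuity, states, actions, timesteps):
        # rtg: (B, T, 1), acuity: (B, T, acuity_dim), states: (B, T, state_dim),
        # actions: (B, T, act_dim), timesteps: (B, T)
        B, T = timesteps.shape
        pos = self.embed_time(timesteps)
        tokens = torch.stack(
            [self.embed_rtg(rtg) + pos,
             self.embed_acuity(acuity) + pos,
             self.embed_state(states) + pos,
             self.embed_action(actions) + pos],
            dim=2,
        ).reshape(B, 4 * T, -1)                 # (r_t, k_t, s_t, a_t) interleaved per step
        mask = nn.Transformer.generate_square_subsequent_mask(4 * T)
        h = self.backbone(tokens, mask=mask)    # causal self-attention over the trajectory
        h = h.reshape(B, T, 4, -1)
        return self.action_head(h[:, :, 2])     # predict a_t from the state token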
We model our trajectories such that the transformer conditions on future desired returns to generate treatment dosages. Specifically, we use the return-to-go \[r_t=\sum_{t^{\prime}=t}^T R_{t^{\prime}},\] which represents the patient's death or survival status, to condition the model. In addition, we propose to condition MeDT on future patient acuity scores (SAPS II), or acuity-to-go, where the acuity score indicates the severity of the patient's illness in the ICU, based on the status of the patient's physiological systems. This leads to more information-dense conditioning, allowing clinicians to interact with the model and guide the policy's generation of treatment dosages.
The acuity scores are split to represent different physiological systems, defined as \[k = (k_c, k_r, k_n, k_l, k_h, k_m, k_o),\] where the components represent the status of the cardiovascular, respiratory, neurological, renal, hepatic, haematologic and other systems, respectively. This enhances the model's usability, allowing clinicians to efficiently steer future dosage recommendations based on the current state of the patient's organ systems. Using these scores, the treatment progress over T time steps forms a trajectory \[ \tau=\left((r_1, k_1, s_1, a_1), (r_2, k_2, s_2, a_2), \ldots, (r_T, k_T, s_T, a_T)\right).\]
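As a small illustration of how such trajectories could be assembled, the sketch below computes returns-to-go as a reverse cumulative sum of the per-step rewards and packs (r_t, k_t, s_t, a_t) tuples. The array shapes and feature counts are hypothetical stand-ins for the actual MIMIC-III preprocessing.

# Toy example of trajectory construction with returns-to-go (placeholder data).
import numpy as np

def returns_to_go(rewards):
    # r_t = sum_{t'=t}^{T} R_{t'}: reverse cumulative sum of the per-step rewards
    return np.cumsum(rewards[::-1])[::-1]

rewards = np.array([0.0, 0.0, 0.0, 1.0])   # terminal +1 if the patient survives
acuity  = np.random.rand(4, 7)             # hindsight (k_c, k_r, k_n, k_l, k_h, k_m, k_o) per step
states  = np.random.rand(4, 10)            # physiological state features
actions = np.random.rand(4, 2)             # e.g. vasopressor and IV-fluid dosages

# trajectory tau = ((r_1, k_1, s_1, a_1), ..., (r_T, k_T, s_T, a_T))
trajectory = list(zip(returns_to_go(rewards), acuity, states, actions))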
In online RL, policies are assessed by having them interact with the environment. In healthcare, however, this would mean experimenting on patients, which is unsafe. As a stand-in for a simulator during inference, we propose to additionally learn an approximate model (state predictor) of \[P_{\theta}(s_{t}\mid a_{< t},s_{< t})\] with an architecture similar to that of the policy model. During inference, this model allows autoregressive generation of a sequence of actions by predicting how the patient state evolves as a result of those actions. The following algorithm details this rollout procedure.
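The following is only a minimal sketch of that rollout loop, assuming a trained policy and state predictor exposing hypothetical predict_action and predict_state helpers; how the return-to-go and acuity-to-go targets are updated between steps follows the paper and is simply held fixed here for brevity.

# Sketch of the offline rollout: the state predictor stands in for the environment.
import torch

@torch.no_grad()
def rollout(policy, state_predictor, s0, target_rtg, target_acuity, horizon):
    states, actions = [s0], []
    rtg, acuity = [target_rtg], [target_acuity]
    for _ in range(horizon):
        # policy conditions on returns-to-go, acuity-to-go and the history so far
        a_t = policy.predict_action(rtg, acuity, states, actions)
        actions.append(a_t)
        # predict how the patient state evolves as a result of the chosen dosages
        s_next = state_predictor.predict_state(states, actions)
        states.append(s_next)
        # conditioning targets kept fixed in this sketch
        rtg.append(rtg[-1])
        acuity.append(acuity[-1])
    return states, actions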
The learned MeDT policy recommends dosages similar to those of clinicians.
Dosages output by MeDT lead to better estimated outcomes than those of baseline methods.
MeDT leads to more stable patient rollouts than DT.
@inproceedings{rahman2023empowering,
title={Empowering Clinicians with Me{DT}: A Framework for Sepsis Treatment},
author={Abdul Rahman, Aamer and Agarwal, Pranav and Michalski, Vincent and Noumeir, Rita and Jouvet, Philippe and Ebrahimi Kahou, Samira},
booktitle={NeurIPS 2023 Workshop on Goal-Conditioned Reinforcement Learning},
year={2023}
}