evergreen/osrf/
worker.rs

1use crate::osrf::addr::BusAddress;
2use crate::osrf::app;
3use crate::osrf::client::{Client, ClientSingleton};
4use crate::osrf::conf;
5use crate::osrf::logging::Logger;
6use crate::osrf::message;
7use crate::osrf::message::Message;
8use crate::osrf::message::MessageStatus;
9use crate::osrf::message::MessageType;
10use crate::osrf::message::Payload;
11use crate::osrf::message::TransportMessage;
12use crate::osrf::method::ParamCount;
13use crate::osrf::server::Server;
14use crate::osrf::session::ServerSession;
15use crate::util;
16use crate::EgResult;
17use mptc::signals::SignalTracker;
18use std::cell::RefMut;
19use std::fmt;
20use std::sync::mpsc;
21
/// Number of seconds an idle worker waits for an inbound request before
/// waking to check for shutdown signals, etc.
const IDLE_WAKE_TIME: u64 = 5;
24
/// Each worker thread is in one of these states.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum WorkerState {
    /// The worker is available to perform work.
    Idle,
    /// The worker is about to perform, or is performing, some work.
    Active,
    /// The worker has left its listen loop and is shutting down.
    Exiting,
}
32
33#[derive(Debug)]
34pub struct WorkerStateEvent {
35    pub worker_id: u64,
36    pub state: WorkerState,
37}
38
39impl WorkerStateEvent {
40    pub fn worker_id(&self) -> u64 {
41        self.worker_id
42    }
43    pub fn state(&self) -> WorkerState {
44        self.state
45    }
46}
47
48/// A Worker runs in its own thread and responds to API requests.
49pub struct Worker {
50    service: String,
51
52    /// Watches for signals
53    sig_tracker: SignalTracker,
54
55    client: Client,
56
57    /// True if the caller has requested a stateful conversation.
58    connected: bool,
59
60    max_requests: usize,
61
62    keepalive: usize,
63
64    /// Currently active session.
65    /// A worker can only have one active session at a time.
66    /// For stateless requests, each new thread results in a new session.
67    /// Starting a new thread/session in a stateful conversation
68    /// results in an error.
69    session: Option<ServerSession>,
70
71    /// Unique ID for tracking/logging each working.
72    worker_id: u64,
73
74    /// Channel for sending worker state info to our parent.
75    to_parent_tx: mpsc::SyncSender<WorkerStateEvent>,
76}
77
78impl fmt::Display for Worker {
79    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80        write!(f, "Worker ({})", self.worker_id)
81    }
82}
83
84impl Worker {
85    pub fn new(
86        service: String,
87        worker_id: u64,
88        sig_tracker: SignalTracker,
89        to_parent_tx: mpsc::SyncSender<WorkerStateEvent>,
90        max_requests: usize,
91        keepalive: usize,
92    ) -> EgResult<Worker> {
93        let client = Client::connect()?;
94
95        Ok(Worker {
96            sig_tracker,
97            service,
98            worker_id,
99            client,
100            to_parent_tx,
101            max_requests,
102            keepalive,
103            session: None,
104            connected: false,
105        })
106    }
107
108    /// Mutable Ref to our under-the-covers client singleton.
109    fn client_internal_mut(&self) -> RefMut<ClientSingleton> {
110        self.client.singleton().borrow_mut()
111    }
112
113    /// Current session
114    ///
115    /// Panics of session on None.
116    fn session(&self) -> &ServerSession {
117        self.session.as_ref().unwrap()
118    }
119
120    fn session_mut(&mut self) -> &mut ServerSession {
121        self.session.as_mut().unwrap()
122    }
123
124    pub fn worker_id(&self) -> u64 {
125        self.worker_id
126    }
127
128    /// Wait for and process inbound API calls.
129    ///
130    /// Exits with Err if this worker encounters a fatal error and should be purged.
131    pub fn listen(&mut self, factory: app::ApplicationWorkerFactory) -> EgResult<()> {
132        let selfstr = format!("{self}");
133
134        let mut app_worker = (factory)();
135
136        app_worker.worker_start(self.client.clone())?;
137
138        let mut request_count: usize = 0;
139
140        // We listen for API calls at an addressed scoped to our
141        // username and domain.
142        let username = self.client.address().username();
143        let domain = self.client.address().domain();
144
145        let service_addr = BusAddress::for_service(username, domain, &self.service);
146        let service_addr = service_addr.as_str().to_string();
147
148        let my_addr = self.client.address().as_str().to_string();
149
150        while request_count < self.max_requests {
151            let timeout: u64;
152            let sent_to: &str;
153
154            if self.connected {
155                // We're in the middle of a stateful conversation.
156                // Listen for messages sent specifically to our bus
157                // address and only wait up to keeplive seconds for
158                // subsequent messages.
159                sent_to = &my_addr;
160                timeout = self.keepalive as u64;
161            } else {
162                // If we are not within a stateful conversation, clear
163                // our bus data and message backlogs since any remaining
164                // data is no longer relevant.
165                self.reset()?;
166
167                sent_to = &service_addr;
168                timeout = IDLE_WAKE_TIME;
169            }
170
171            // work_occurred will be true if we handled a message or
172            // had to address a stateful session timeout.
173            let (work_occurred, msg_handled) =
174                match self.handle_recv(&mut app_worker, timeout, sent_to) {
175                    Ok(w) => w,
176                    Err(e) => {
177                        log::error!("Error in main loop error: {e}");
178                        break;
179                    }
180                };
181
182            // If we are connected, we remain Active and avoid counting
183            // subsequent requests within this stateful converstation
184            // toward our overall request count.
185            if self.connected {
186                continue;
187            }
188
189            if work_occurred {
190                // also true if msg_handled
191
192                // If we performed any work and we are outside of a
193                // keepalive loop, let our worker know a stateless
194                // request or stateful conversation has just completed.
195                if let Err(e) = app_worker.end_session() {
196                    log::error!("end_session() returned an error: {e}");
197                    break;
198                }
199
200                self.set_idle()?;
201
202                if msg_handled {
203                    // Increment our message handled count.
204                    // Each connected session counts as 1 "request".
205                    request_count += 1;
206
207                    // An inbound message may have modified our
208                    // thread-scoped locale.  Reset our locale back
209                    // to the default so the previous locale does not
210                    // affect future messages.
211                    message::reset_thread_locale();
212                }
213            } else {
214                // Let the worker know we woke up and nothing interesting
215                // happened.
216                if let Err(e) = app_worker.worker_idle_wake(self.connected) {
217                    log::error!("worker_idle_wake() returned an error: {e}");
218                    break;
219                }
220            }
221
222            // Did we get a shutdown signal?  Check this after
223            // "end_session()" so we don't interrupt a conversation to
224            // shutdown.
225            if self.sig_tracker.any_shutdown_requested() {
226                log::info!("{selfstr} received a stop signal");
227                break;
228            }
229        }
230
231        log::debug!("{self} exiting listen loop and cleaning up");
232
233        // Tell the worker to cleanup
234        if let Err(e) = app_worker.worker_end() {
235            // Avoid treating this as a fatal worker error
236            log::error!("{selfstr} worker_end failed {e}");
237        }
238
239        self.set_exiting()?;
240        self.reset()
241    }
242
243    /// Call recv() on our message bus and process the response.
244    ///
245    /// Return value consists of (work_occurred, msg_handled).
246    fn handle_recv(
247        &mut self,
248        app_worker: &mut Box<dyn app::ApplicationWorker>,
249        timeout: u64,
250        sent_to: &str,
251    ) -> EgResult<(bool, bool)> {
252        let selfstr = format!("{self}");
253
254        let tmsg_op = self
255            .client_internal_mut()
256            .bus_mut()
257            .recv(timeout, Some(sent_to))?;
258
259        let tmsg = match tmsg_op {
260            Some(v) => v,
261            None => {
262                if !self.connected {
263                    // No new message to handle and no timeout to address.
264                    return Ok((false, false));
265                }
266
267                // Caller failed to send a message within the keepliave interval.
268                log::warn!("{selfstr} timeout waiting on request while connected");
269
270                self.reply_with_status(MessageStatus::Timeout, "Timeout")?;
271                self.set_active()?;
272
273                return Ok((true, false)); // work occurred
274            }
275        };
276
277        self.set_active()?;
278
279        if !self.connected {
280            // Any message received in a non-connected state represents
281            // the start of a session.  For stateful convos, the
282            // current message will be a CONNECT.  Otherwise, it will
283            // be a one-off request.
284            app_worker.start_session()?;
285        }
286
287        if let Err(e) = self.handle_transport_message(tmsg, app_worker) {
288            // An error within our worker's method handler is not enough
289            // to shut down the worker.  Log, force a disconnect on the
290            // session (if applicable) and move on.
291            log::error!("{selfstr} error handling message: {e}");
292            self.connected = false;
293        }
294
295        Ok((true, true)) // work occurred, message handled
296    }
297
298    /// Tell our parent we're about to perform some work.
299    fn set_active(&mut self) -> EgResult<()> {
300        if let Err(e) = self.notify_state(WorkerState::Active) {
301            Err(format!(
302                "{self} failed to notify parent of Active state. Exiting. {e}"
303            ))?;
304        }
305
306        Ok(())
307    }
308
309    /// Tell our parent we're available to perform work.
310    fn set_idle(&mut self) -> EgResult<()> {
311        if let Err(e) = self.notify_state(WorkerState::Idle) {
312            Err(format!(
313                "{self} failed to notify parent of Idle state. Exiting. {e}"
314            ))?;
315        }
316
317        Ok(())
318    }
319
320    /// Tell our parent we're available to perform work.
321    fn set_exiting(&mut self) -> EgResult<()> {
322        if let Err(e) = self.notify_state(WorkerState::Exiting) {
323            Err(format!(
324                "{self} failed to notify parent of Exiting state. Exiting. {e}"
325            ))?;
326        }
327
328        Ok(())
329    }
330
331    fn handle_transport_message(
332        &mut self,
333        mut tmsg: message::TransportMessage,
334        app_worker: &mut Box<dyn app::ApplicationWorker>,
335    ) -> EgResult<()> {
336        // Always adopt the log trace of an inbound API call.
337        Logger::set_log_trace(tmsg.osrf_xid());
338
339        if self.session.is_none() || self.session().thread().ne(tmsg.thread()) {
340            log::trace!("server: creating new server session for {}", tmsg.thread());
341
342            self.session = Some(ServerSession::new(
343                self.client.clone(),
344                &self.service,
345                tmsg.thread(),
346                0, // thread trace -- updated later as needed
347                BusAddress::parse_str(tmsg.from())?,
348            ));
349        }
350
351        for msg in tmsg.body_mut().drain(..) {
352            self.handle_message(msg, app_worker)?;
353        }
354
355        Ok(())
356    }
357
358    // Clear our local message bus and reset state maintenance values.
359    fn reset(&mut self) -> EgResult<()> {
360        self.connected = false;
361        self.session = None;
362        self.client.clear()
363    }
364
365    fn handle_message(
366        &mut self,
367        msg: message::Message,
368        app_worker: &mut Box<dyn app::ApplicationWorker>,
369    ) -> EgResult<()> {
370        self.session_mut().set_last_thread_trace(msg.thread_trace());
371        self.session_mut().clear_responded_complete();
372
373        log::trace!("{self} received message of type {:?}", msg.mtype());
374
375        match msg.mtype() {
376            message::MessageType::Disconnect => {
377                log::trace!("{self} received a DISCONNECT");
378                self.reset()?;
379                Ok(())
380            }
381
382            message::MessageType::Connect => {
383                log::trace!("{self} received a CONNECT");
384
385                if self.connected {
386                    return self.reply_bad_request("Worker is already connected");
387                }
388
389                self.connected = true;
390                self.reply_with_status(MessageStatus::Ok, "OK")
391            }
392
393            message::MessageType::Request => {
394                log::trace!("{self} received a REQUEST");
395                self.handle_request(msg, app_worker)
396            }
397
398            _ => self.reply_bad_request("Unexpected message type"),
399        }
400    }
401
402    fn reply_with_status(&mut self, stat: MessageStatus, stat_text: &str) -> EgResult<()> {
403        let tmsg = TransportMessage::with_body(
404            self.session().sender().as_str(),
405            self.client.address().as_str(),
406            self.session().thread(),
407            Message::new(
408                MessageType::Status,
409                self.session().last_thread_trace(),
410                Payload::Status(message::Status::new(stat, stat_text, "osrfStatus")),
411            ),
412        );
413
414        self.client_internal_mut()
415            .get_domain_bus(self.session().sender().domain())?
416            .send(tmsg)
417    }
418
419    fn handle_request(
420        &mut self,
421        mut msg: message::Message,
422        app_worker: &mut Box<dyn app::ApplicationWorker>,
423    ) -> EgResult<()> {
424        let method_call = match msg.take_payload() {
425            message::Payload::Method(m) => m,
426            _ => return self.reply_bad_request("Request sent without a MethoCall payload"),
427        };
428
429        let param_count = method_call.params().len();
430        let api_name = method_call.method().to_string();
431
432        let log_params = util::stringify_params(
433            &api_name,
434            method_call.params(),
435            conf::config().log_protect(),
436        );
437
438        // Log the API call
439        log::info!("CALL: {} {}", api_name, log_params);
440
441        // Before we begin processing a service-level request, clear our
442        // local message bus to avoid encountering any stale messages
443        // lingering from the previous conversation.
444        if !self.connected {
445            self.client.clear()?;
446        }
447
448        // Clone the method since we have mutable borrows below.  Note
449        // this is the method definition, not the param-laden request.
450        let mut method_def = Server::methods().get(&api_name).cloned();
451
452        if method_def.is_none() {
453            // Atomic methods are not registered/published in advance
454            // since every method has an atomic variant.
455            // Find the root method and use it.
456            if api_name.ends_with(".atomic") {
457                let meth = api_name.replace(".atomic", "");
458                if let Some(m) = Server::methods().get(&meth) {
459                    method_def = Some(m.clone());
460
461                    // Creating a new queue tells our session to treat
462                    // this as an atomic request.
463                    self.session_mut().new_atomic_resp_queue();
464                }
465            }
466        }
467
468        if method_def.is_none() {
469            log::warn!("Method not found: {}", api_name);
470
471            return self.reply_with_status(
472                MessageStatus::MethodNotFound,
473                &format!("Method not found: {}", api_name),
474            );
475        }
476
477        let method_def = method_def.unwrap();
478        let pcount = method_def.param_count();
479
480        // Make sure the number of params sent by the caller matches the
481        // parameter count for the method.
482        if !ParamCount::matches(pcount, param_count as u8) {
483            return self.reply_bad_request(&format!(
484                "Invalid param count sent: method={} sent={} needed={}",
485                api_name, param_count, &pcount,
486            ));
487        }
488
489        // Verify paramter types are correct, at least superficially.
490        // Do this after deserialization.
491        if let Some(param_defs) = method_def.params() {
492            for (idx, param_def) in param_defs.iter().enumerate() {
493                // There may be more param defs than parameters if
494                // some param are optional.
495                if let Some(param_val) = method_call.params().get(idx) {
496                    if idx >= pcount.minimum() as usize && param_val.is_null() {
497                        // NULL placeholders for non-required parameters are
498                        // allowed.
499                        continue;
500                    }
501                    if !param_def.datatype.matches(param_val) {
502                        return self.reply_bad_request(&format!(
503                            "Invalid paramter type: wanted={} got={}",
504                            param_def.datatype,
505                            param_val.clone().dump()
506                        ));
507                    }
508                } else {
509                    // More defs than actual params. Verification complete.
510                    break;
511                }
512            }
513        }
514
515        // Call the API
516        if let Err(err) = (method_def.handler())(app_worker, self.session_mut(), method_call) {
517            let msg = format!("{self} method {api_name} exited: \"{err}\"");
518            log::error!("{msg}");
519            app_worker.api_call_error(&api_name, err);
520            self.reply_server_error(&msg)?;
521            Err(msg)?;
522        }
523
524        if !self.session().responded_complete() {
525            self.session_mut().send_complete()
526        } else {
527            Ok(())
528        }
529    }
530
531    fn reply_server_error(&mut self, text: &str) -> EgResult<()> {
532        self.connected = false;
533
534        let msg = Message::new(
535            MessageType::Status,
536            self.session().last_thread_trace(),
537            Payload::Status(message::Status::new(
538                MessageStatus::InternalServerError,
539                &format!("Internal Server Error: {text}"),
540                "osrfStatus",
541            )),
542        );
543
544        let tmsg = TransportMessage::with_body(
545            self.session().sender().as_str(),
546            self.client.address().as_str(),
547            self.session().thread(),
548            msg,
549        );
550
551        self.client_internal_mut()
552            .get_domain_bus(self.session().sender().domain())?
553            .send(tmsg)
554    }
555
556    fn reply_bad_request(&mut self, text: &str) -> EgResult<()> {
557        self.connected = false;
558
559        let msg = Message::new(
560            MessageType::Status,
561            self.session().last_thread_trace(),
562            Payload::Status(message::Status::new(
563                MessageStatus::BadRequest,
564                &format!("Bad Request: {text}"),
565                "osrfStatus",
566            )),
567        );
568
569        let tmsg = TransportMessage::with_body(
570            self.session().sender().as_str(),
571            self.client.address().as_str(),
572            self.session().thread(),
573            msg,
574        );
575
576        self.client_internal_mut()
577            .get_domain_bus(self.session().sender().domain())?
578            .send(tmsg)
579    }
580
581    /// Notify the parent process of this worker's active state.
582    fn notify_state(&self, state: WorkerState) -> EgResult<()> {
583        log::trace!("{self} notifying parent of state change => {state:?}");
584
585        self.to_parent_tx
586            .send(WorkerStateEvent {
587                state,
588                worker_id: self.worker_id(),
589            })
590            .map_err(|e| format!("mpsc::SendError: {e}").into())
591    }
592}