diff --git a/smsbot/cli.py b/smsbot/cli.py index 3c12313..f1a651e 100644 --- a/smsbot/cli.py +++ b/smsbot/cli.py @@ -11,7 +11,7 @@ from asgiref.wsgi import WsgiToAsgi from smsbot.telegram import TelegramSmsBot from smsbot.utils import get_smsbot_version -from smsbot.webhook_handler import TwilioWebhookHandler +from smsbot.webhook import TwilioWebhookHandler def main(): @@ -70,7 +70,11 @@ def main(): for chat_id in config.get("telegram", "default_subscribers").split(","): telegram_bot.subscribers.append(int(chat_id.strip())) - webhooks = TwilioWebhookHandler() + # Init the webhook handler + webhooks = TwilioWebhookHandler( + account_sid=config.get("twilio", "account_sid", fallback=None), + auth_token=config.get("twilio", "auth_token", fallback=None), + ) webhooks.set_bot(telegram_bot) # Build a uvicorn ASGI server diff --git a/smsbot/webhook_handler.py b/smsbot/webhook.py similarity index 65% rename from smsbot/webhook_handler.py rename to smsbot/webhook.py index fe75d2d..0b81ca2 100644 --- a/smsbot/webhook_handler.py +++ b/smsbot/webhook.py @@ -16,47 +16,22 @@ MESSAGE_COUNT = Counter("webhook_message_count", "Total number of messages proce CALL_COUNT = Counter("webhook_call_count", "Total number of calls processed") -def validate_twilio_request(func): - """Validates that incoming requests genuinely originated from Twilio""" - - @wraps(func) - async def decorated_function(*args, **kwargs): - # Create an instance of the RequestValidator class - twilio_token = os.environ.get("SMSBOT_TWILIO_AUTH_TOKEN") - - if not twilio_token: - current_app.logger.warning( - "Twilio request validation skipped due to SMSBOT_TWILIO_AUTH_TOKEN missing" - ) - return await func(*args, **kwargs) - - validator = RequestValidator(twilio_token) - - # Validate the request using its URL, POST data, - # and X-TWILIO-SIGNATURE header - request_valid = validator.validate( - request.url, - request.form, - request.headers.get("X-TWILIO-SIGNATURE", ""), - ) - - # Continue processing the request if it's valid, return a 403 error if - # it's not - if request_valid or current_app.debug: - return func(*args, **kwargs) - return abort(403) - - return decorated_function - class TwilioWebhookHandler(object): - def __init__(self): + def __init__(self, account_sid: str | None = None, auth_token: str | None = None): self.app = Flask(self.__class__.__name__) self.app.add_url_rule("/", "index", self.index, methods=["GET"]) self.app.add_url_rule("/health", "health", self.health, methods=["GET"]) self.app.add_url_rule("/message", "message", self.message, methods=["POST"]) self.app.add_url_rule("/call", "call", self.call, methods=["POST"]) + self.account_sid = account_sid + self.auth_token = auth_token + + # Wrap validation around hook endpoints + self.message = self.validate_twilio_request(self.message) + self.call = self.validate_twilio_request(self.call) + # Add prometheus wsgi middleware to route /metrics requests self.app.wsgi_app = DispatcherMiddleware( self.app.wsgi_app, @@ -65,6 +40,35 @@ class TwilioWebhookHandler(object): }, ) + def validate_twilio_request(self, func): + """Validates that incoming requests genuinely originated from Twilio""" + + @wraps(func) + async def decorated_function(*args, **kwargs): + # Create an instance of the RequestValidator class + if not self.auth_token: + current_app.logger.warning( + "Twilio request validation skipped due to Twilio Auth Token missing" + ) + return await func(*args, **kwargs) + validator = RequestValidator(self.auth_token) + + # Validate the request using its URL, POST data, + # and X-TWILIO-SIGNATURE header + request_valid = validator.validate( + request.url, + request.form, + request.headers.get("X-TWILIO-SIGNATURE", ""), + ) + + # Continue processing the request if it's valid, return a 403 error if + # it's not + if request_valid or current_app.debug: + return await func(*args, **kwargs) + return abort(403) + + return decorated_function + def set_bot(self, bot): self.bot = bot @@ -80,7 +84,6 @@ class TwilioWebhookHandler(object): } @time(REQUEST_TIME) - @validate_twilio_request async def message(self): """Handle incoming SMS messages from Twilio""" current_app.logger.info( @@ -96,7 +99,6 @@ class TwilioWebhookHandler(object): return '' @time(REQUEST_TIME) - @validate_twilio_request async def call(self): """Handle incoming calls from Twilio""" current_app.logger.info(