Merge modifications for more separate backend functions #69
@@ -1,4 +1,6 @@
 | 
				
			|||||||
"""Module used for handling cache"""
 | 
					"""Module used for handling cache"""
 | 
				
			||||||
 | 
					import hashlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pymemcache import serde
 | 
					from pymemcache import serde
 | 
				
			||||||
from pymemcache.client.base import Client
 | 
					from pymemcache.client.base import Client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -73,3 +75,62 @@ else:
 | 
				
			|||||||
        encoding='utf-8',
 | 
					        encoding='utf-8',
 | 
				
			||||||
        serde=serde.pickle_serde
 | 
					        serde=serde.pickle_serde
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Cache for payment architecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def make_credit_cache_key(user_id: str, order_id: str) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generate a cache key from user_id and order_id using md5.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        user_id (str): The user's ID.
 | 
				
			||||||
 | 
					        order_id (str): The PayPal order ID.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        str: A unique cache key.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Concatenate and hash to avoid collisions and keep key size small
 | 
				
			||||||
 | 
					    raw_key = f"{user_id}:{order_id}"
 | 
				
			||||||
 | 
					    return hashlib.md5(raw_key.encode('utf-8')).hexdigest()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CreditCache:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Handles storing and retrieving credits to grant for a user/order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Methods:
 | 
				
			||||||
 | 
					        set_credits(user_id, order_id, credits):
 | 
				
			||||||
 | 
					            Store the credits for a user/order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        get_credits(user_id, order_id):
 | 
				
			||||||
 | 
					            Retrieve the credits for a user/order.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def set_credits(user_id: str, order_id: str, credits_to_grant: int) -> None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Store the credits to be granted for a user/order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            user_id (str): The user's ID.
 | 
				
			||||||
 | 
					            order_id (str): The PayPal order ID.
 | 
				
			||||||
 | 
					            credits (int): The amount of credits to grant.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        cache_key = make_credit_cache_key(user_id, order_id)
 | 
				
			||||||
 | 
					        client.set(cache_key, credits_to_grant)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_credits(user_id: str, order_id: str) -> int | None:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Retrieve the credits to be granted for a user/order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            user_id (str): The user's ID.
 | 
				
			||||||
 | 
					            order_id (str): The PayPal order ID.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            int | None: The credits to grant, or None if not found.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        cache_key = make_credit_cache_key(user_id, order_id)
 | 
				
			||||||
 | 
					        return client.get(cache_key)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -46,6 +46,8 @@ app.include_router(toilets_router)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Include the payment router for interacting with paypal sdk.
 | 
					# Include the payment router for interacting with paypal sdk.
 | 
				
			||||||
# See src/payment/payment_router.py for more information on how to call.
 | 
					# See src/payment/payment_router.py for more information on how to call.
 | 
				
			||||||
 | 
					# Call with "/orders/new" to initiate a payment with an order request (step 1)
 | 
				
			||||||
 | 
					# Call with "/orders/{order_id}/{user_id}capture" to capture a payment and grant the user the due credits (step 2)
 | 
				
			||||||
app.include_router(payment_router)
 | 
					app.include_router(payment_router)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -54,3 +56,4 @@ app.include_router(payment_router)
 | 
				
			|||||||
# Call with "/landmark/{landmark_uuid}" for getting landmark by UUID.
 | 
					# Call with "/landmark/{landmark_uuid}" for getting landmark by UUID.
 | 
				
			||||||
# Call with "/trip//trip/recompute-time/{trip_uuid}/{removed_landmark_uuid}" for updating trip times.
 | 
					# Call with "/trip//trip/recompute-time/{trip_uuid}/{removed_landmark_uuid}" for updating trip times.
 | 
				
			||||||
app.include_router(trips_router)
 | 
					app.include_router(trips_router)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,20 +1,22 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
from typing import Literal
 | 
					from typing import Literal
 | 
				
			||||||
from datetime import datetime, timedelta
 | 
					from datetime import datetime, timedelta
 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import json
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from pydantic import BaseModel, Field, field_validator
 | 
					 | 
				
			||||||
import requests
 | 
					import requests
 | 
				
			||||||
 | 
					from pydantic import BaseModel, Field, field_validator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..configuration.environment import Environment
 | 
					from ..configuration.environment import Environment
 | 
				
			||||||
 | 
					from ..cache import CreditCache, make_credit_cache_key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Define the base URL, might move that to toml file
 | 
					 | 
				
			||||||
BASE_URL = 'https://api-m.sandbox.paypal.com'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Intialize the logger
 | 
					# Intialize the logger
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Define the base URL, might move that to toml file
 | 
				
			||||||
 | 
					BASE_URL_PROD = 'https://api-m.paypal.com'
 | 
				
			||||||
 | 
					BASE_URL_SANDBOX = 'https://api-m.sandbox.paypal.com'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BasketItem(BaseModel):
 | 
					class BasketItem(BaseModel):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@@ -42,7 +44,7 @@ class Item(BaseModel):
 | 
				
			|||||||
    name: str
 | 
					    name: str
 | 
				
			||||||
    description: str
 | 
					    description: str
 | 
				
			||||||
    unit_price: float
 | 
					    unit_price: float
 | 
				
			||||||
    anyway_credits: int
 | 
					    unit_credits: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def item_from_sql(item_id: str):
 | 
					def item_from_sql(item_id: str):
 | 
				
			||||||
@@ -60,8 +62,8 @@ def item_from_sql(item_id: str):
 | 
				
			|||||||
        id = '12345678',
 | 
					        id = '12345678',
 | 
				
			||||||
        name = 'test_item',
 | 
					        name = 'test_item',
 | 
				
			||||||
        description = 'lorem ipsum',
 | 
					        description = 'lorem ipsum',
 | 
				
			||||||
        unit_price = 420,
 | 
					        unit_price = 0.1,
 | 
				
			||||||
        anyway_credits = 99
 | 
					        unit_credits = 5
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -84,7 +86,8 @@ class OrderRequest(BaseModel):
 | 
				
			|||||||
    created_at: datetime = Field(default_factory=datetime.now)
 | 
					    created_at: datetime = Field(default_factory=datetime.now)
 | 
				
			||||||
    updated_at: datetime = Field(default_factory=datetime.now)
 | 
					    updated_at: datetime = Field(default_factory=datetime.now)
 | 
				
			||||||
    items: list[Item] = Field(default_factory=list)
 | 
					    items: list[Item] = Field(default_factory=list)
 | 
				
			||||||
    total_price: float = 0
 | 
					    total_price: float = None
 | 
				
			||||||
 | 
					    total_credits: int = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @field_validator('basket')
 | 
					    @field_validator('basket')
 | 
				
			||||||
    def validate_basket(cls, v):
 | 
					    def validate_basket(cls, v):
 | 
				
			||||||
@@ -104,15 +107,18 @@ class OrderRequest(BaseModel):
 | 
				
			|||||||
        return
 | 
					        return
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def load_items_and_price(self):
 | 
					    def load_items_and_price(self):
 | 
				
			||||||
 | 
					        # This should be automatic upon initialization of the class
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Loads item details from database and calculates the total price.
 | 
					        Loads item details from database and calculates the total price as well as the total credits to be granted.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.items = []
 | 
					        self.items = []
 | 
				
			||||||
        self.total_price = 0
 | 
					        self.total_price = 0
 | 
				
			||||||
 | 
					        self.total_credits = 0
 | 
				
			||||||
        for basket_item in self.basket:
 | 
					        for basket_item in self.basket:
 | 
				
			||||||
            item = item_from_sql(basket_item.id)
 | 
					            item = item_from_sql(basket_item.id)
 | 
				
			||||||
            self.items.append(item)
 | 
					            self.items.append(item)
 | 
				
			||||||
            self.total_price += item.unit_price * basket_item.quantity
 | 
					            self.total_price += item.unit_price * basket_item.quantity      # increment price
 | 
				
			||||||
 | 
					            self.total_credits += item.unit_credits * basket_item.quantity  # increment credit balance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_paypal_items(self): 
 | 
					    def to_paypal_items(self): 
 | 
				
			||||||
@@ -138,11 +144,8 @@ class OrderRequest(BaseModel):
 | 
				
			|||||||
        return item_list
 | 
					        return item_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Payment handler class for managing PayPal payments
 | 
					# Payment handler class for managing PayPal payments
 | 
				
			||||||
class PaypalHandler:
 | 
					class PaypalClient:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Handles PayPal payment operations.
 | 
					    Handles PayPal payment operations.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -159,7 +162,6 @@ class PaypalHandler:
 | 
				
			|||||||
        "expires_at": 0
 | 
					        "expires_at": 0
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    order_request = None
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(
 | 
					    def __init__(
 | 
				
			||||||
            self,
 | 
					            self,
 | 
				
			||||||
@@ -178,11 +180,11 @@ class PaypalHandler:
 | 
				
			|||||||
        if sandbox_mode :
 | 
					        if sandbox_mode :
 | 
				
			||||||
            self.id = Environment.paypal_id_sandbox
 | 
					            self.id = Environment.paypal_id_sandbox
 | 
				
			||||||
            self.key = Environment.paypal_key_sandbox
 | 
					            self.key = Environment.paypal_key_sandbox
 | 
				
			||||||
            self.base_url = BASE_URL
 | 
					            self.base_url = BASE_URL_SANDBOX
 | 
				
			||||||
        else :
 | 
					        else :
 | 
				
			||||||
            self.id = Environment.paypal_id_prod
 | 
					            self.id = Environment.paypal_id_prod
 | 
				
			||||||
            self.key = Environment.paypal_key_prod
 | 
					            self.key = Environment.paypal_key_prod
 | 
				
			||||||
            self.base_url = 'https://api-m.paypal.com'
 | 
					            self.base_url = BASE_URL_PROD
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -239,6 +241,11 @@ class PaypalHandler:
 | 
				
			|||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            dict | None: PayPal order response JSON, or None if failed.
 | 
					            dict | None: PayPal order response JSON, or None if failed.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Fetch details of order from mart database and compute total credits and price
 | 
				
			||||||
 | 
					        order_request.load_items_and_price()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Prepare payload for post request to paypal API
 | 
				
			||||||
        order_data = {
 | 
					        order_data = {
 | 
				
			||||||
            'intent': 'CAPTURE',
 | 
					            'intent': 'CAPTURE',
 | 
				
			||||||
            'purchase_units': [
 | 
					            'purchase_units': [
 | 
				
			||||||
@@ -286,20 +293,24 @@ class PaypalHandler:
 | 
				
			|||||||
        # order_id (key): json.loads(order_response.text)["id"]
 | 
					        # order_id (key): json.loads(order_response.text)["id"]
 | 
				
			||||||
        # user_id       : order_request.user_id
 | 
					        # user_id       : order_request.user_id
 | 
				
			||||||
        # created_at    : order_request.created_at
 | 
					        # created_at    : order_request.created_at
 | 
				
			||||||
        # status        : order_request.status
 | 
					        # status        : PENDING
 | 
				
			||||||
        # basket (json) : OrderDetails.jsonify()
 | 
					        # basket (json) : OrderDetails.jsonify()
 | 
				
			||||||
        # total_price   : order_request.total_price
 | 
					        # total_price   : order_request.total_price
 | 
				
			||||||
        # currency      : order_request.currency
 | 
					        # currency      : order_request.currency
 | 
				
			||||||
        # updated_at    : order_request.created_at
 | 
					        # updated_at    : order_request.created_at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Now we can increment the supabase balance by so many credits as in the balance.
 | 
					        # Create a cache item for credits to be granted to user
 | 
				
			||||||
        #TODO still
 | 
					        CreditCache.set_credits(
 | 
				
			||||||
 | 
					            user_id  = order_request.user_id,
 | 
				
			||||||
 | 
					            order_id = json.loads(order_response.text)["id"],
 | 
				
			||||||
 | 
					            credits_to_grant = order_request.total_credits)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return order_response.json()
 | 
					        return order_response.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Standalone function to capture a payment
 | 
					    # Standalone function to capture a payment
 | 
				
			||||||
    def capture(self, order_id: str):
 | 
					    def capture(self, user_id: str, order_id: str):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Captures payment for a PayPal order.
 | 
					        Captures payment for a PayPal order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -314,7 +325,7 @@ class PaypalHandler:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        try: 
 | 
					        try: 
 | 
				
			||||||
            capture_response = requests.post(
 | 
					            capture_response = requests.post(
 | 
				
			||||||
                url  = f'{BASE_URL}/v2/checkout/orders/{order_id}/capture',
 | 
					                url  = f'{self.base_url}/v2/checkout/orders/{order_id}/capture',
 | 
				
			||||||
                headers = {'Authorization': f'Bearer {access_token}'},
 | 
					                headers = {'Authorization': f'Bearer {access_token}'},
 | 
				
			||||||
                json = {},
 | 
					                json = {},
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@@ -322,11 +333,12 @@ class PaypalHandler:
 | 
				
			|||||||
            logger.error(f'Error while requesting access token: {exc}')
 | 
					            logger.error(f'Error while requesting access token: {exc}')
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Raise exception if API call failed
 | 
				
			||||||
        capture_response.raise_for_status()
 | 
					        capture_response.raise_for_status()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # todo check status code + try except
 | 
					        
 | 
				
			||||||
        print(capture_response.text)
 | 
					
 | 
				
			||||||
        # order_id = json.loads(response.text)["id"]
 | 
					        # print(capture_response.text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO: update status to PAID in sql database
 | 
					        # TODO: update status to PAID in sql database
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,14 +1,24 @@
 | 
				
			|||||||
 | 
					import logging
 | 
				
			||||||
from typing import Literal
 | 
					from typing import Literal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from fastapi import FastAPI, HTTPException
 | 
					from fastapi import APIRouter, HTTPException
 | 
				
			||||||
from ..payments import PaypalHandler, OrderRequest
 | 
					from ..payments import PaypalClient, OrderRequest
 | 
				
			||||||
 | 
					from ..supabase.supabase import SupabaseClient
 | 
				
			||||||
 | 
					from ..cache import CreditCache, make_credit_cache_key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
app = FastAPI()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Initialize PayPal handler
 | 
					# Create a PayPal & Supabase client
 | 
				
			||||||
paypal_handler = PaypalHandler(sandbox_mode=True)
 | 
					paypal_client = PaypalClient(sandbox_mode=False)
 | 
				
			||||||
 | 
					supabase = SupabaseClient()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.post("/orders/new")
 | 
					# Initialize the API router
 | 
				
			||||||
 | 
					router = APIRouter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Initialize the logger
 | 
				
			||||||
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/orders/new")
 | 
				
			||||||
def create_order(
 | 
					def create_order(
 | 
				
			||||||
        user_id: str,
 | 
					        user_id: str,
 | 
				
			||||||
        basket: list,
 | 
					        basket: list,
 | 
				
			||||||
@@ -34,12 +44,12 @@ def create_order(
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Process the order and return the details
 | 
					    # Process the order and return the details
 | 
				
			||||||
    return paypal_handler.order(order_request = order)
 | 
					    return paypal_client.order(order_request = order)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.post("/orders/{order_id}/capture")
 | 
					@router.post("/orders/{order_id}/{user_id}capture")
 | 
				
			||||||
def capture_order(order_id: str):
 | 
					def capture_order(order_id: str, user_id: str):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Captures payment for an existing PayPal order.
 | 
					    Captures payment for an existing PayPal order.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -49,7 +59,20 @@ def capture_order(order_id: str):
 | 
				
			|||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
        dict: The PayPal capture response.
 | 
					        dict: The PayPal capture response.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    result = paypal_handler.capture(order_id)
 | 
					    # Capture the payment
 | 
				
			||||||
 | 
					    result = paypal_client.capture(order_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Grant the user the correct amount of credits:
 | 
				
			||||||
 | 
					    credits = CreditCache.get_credits(user_id, order_id)
 | 
				
			||||||
 | 
					    if credits:
 | 
				
			||||||
 | 
					        supabase.increment_credit_balance(
 | 
				
			||||||
 | 
					            user_id=user_id,
 | 
				
			||||||
 | 
					            amount=credits
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        logger.info('Payment capture succeeded: incrementing balance of user {user_id} by {credits}.')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        logger.error('Capture payment failed. Could not find cache key for user {user_id} and order {order_id}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return result
 | 
					    return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +0,0 @@
 | 
				
			|||||||
"""This module contains the descriptions of items to be purchased in the AnyWay store."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							
								
								
									
										46
									
								
								backend/src/tests/test_payment.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								backend/src/tests/test_payment.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					"""Collection of tests to ensure correct implementation and track progress of paypal payments."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from fastapi.testclient import TestClient
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..main import app
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture(scope="module")
 | 
				
			||||||
 | 
					def client():
 | 
				
			||||||
 | 
					    """Client used to call the app."""
 | 
				
			||||||
 | 
					    return TestClient(app)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    "location,status_code",
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ([45.7576485, 4.8330241], 200),    # Lyon, France
 | 
				
			||||||
 | 
					        ([41.4020572, 2.1818985], 200),    # Barcelona, Spain
 | 
				
			||||||
 | 
					        ([59.3293, 18.0686], 200),         # Stockholm, Sweden
 | 
				
			||||||
 | 
					        ([43.6532, -79.3832], 200),        # Toronto, Canada
 | 
				
			||||||
 | 
					        ([38.7223, -9.1393], 200),         # Lisbon, Portugal
 | 
				
			||||||
 | 
					        ([6.5244, 3.3792], 200),           # Lagos, Nigeria
 | 
				
			||||||
 | 
					        ([17.3850, 78.4867], 200),         # Hyderabad, India
 | 
				
			||||||
 | 
					        ([30.0444, 31.2357], 200),         # Cairo, Egypt
 | 
				
			||||||
 | 
					        ([50.8503, 4.3517], 200),          # Brussels, Belgium
 | 
				
			||||||
 | 
					        ([35.2271, -80.8431], 200),        # Charlotte, USA
 | 
				
			||||||
 | 
					        ([10.4806, -66.9036], 200),        # Caracas, Venezuela
 | 
				
			||||||
 | 
					        ([9.51074, -13.71118], 200),       # Conakry, Guinea
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_nearby(client, location, status_code):    # pylint: disable=redefined-outer-name
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Test n°1 : Verify handling of invalid input.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        client:
 | 
				
			||||||
 | 
					        request:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    response = client.post(f"/get-nearby/landmarks/{location[0]}/{location[1]}")
 | 
				
			||||||
 | 
					    suggestions = response.json()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # checks :
 | 
				
			||||||
 | 
					    assert response.status_code == status_code  # check for successful planning
 | 
				
			||||||
 | 
					    assert isinstance(suggestions, list)  # check that the return type is a list
 | 
				
			||||||
 | 
					    assert len(suggestions) > 0
 | 
				
			||||||
		Reference in New Issue
	
	Block a user