Compare commits
20 Commits
bf8b64aacf
...
backend/mi
| Author | SHA1 | Date | |
|---|---|---|---|
| 51b7117c6d | |||
| 9c930996c7 | |||
| d9724ff07d | |||
| a884b9ee14 | |||
| bfc0c9adae | |||
| 510aabcb0a | |||
| fe1b42fff9 | |||
| b4cac3a357 | |||
| 54f541382e | |||
| 29ac462725 | |||
| d374dc333f | |||
| ab03cee3e3 | |||
| f86174bc11 | |||
| 3bdcdea850 | |||
| 5549f8b0e5 | |||
| b201dfe97c | |||
| b65d184f48 | |||
| 16b35ab5af | |||
| 011671832a | |||
| f2237bd721 |
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -2,7 +2,7 @@
|
||||
cache_XML/
|
||||
|
||||
# secrets
|
||||
*secrets.yaml
|
||||
*.env
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
FROM python:3.12-slim-bookworm
|
||||
# use python 3.12 as a base image
|
||||
FROM docker.io/python:3.12-alpine
|
||||
|
||||
# The installer requires curl (and certificates) to download the release archive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends curl ca-certificates
|
||||
|
||||
# Download the latest installer
|
||||
ADD https://astral.sh/uv/install.sh /uv-installer.sh
|
||||
|
||||
# Run the installer then remove it
|
||||
RUN sh /uv-installer.sh && rm /uv-installer.sh
|
||||
|
||||
# Ensure the installed binary is on the `PATH`
|
||||
ENV PATH="/root/.local/bin/:$PATH"
|
||||
# use the latest version of uv, independently of the python version
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
def main():
|
||||
print("Hello from backend!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -10,6 +10,7 @@ dependencies = [
|
||||
"certifi==2024.12.14 ; python_full_version >= '3.6'",
|
||||
"charset-normalizer==3.4.1 ; python_full_version >= '3.7'",
|
||||
"click==8.1.8 ; python_full_version >= '3.7'",
|
||||
"dotenv>=0.9.9",
|
||||
"fastapi==0.115.7 ; python_full_version >= '3.8'",
|
||||
"fastapi-cli==0.0.7 ; python_full_version >= '3.8'",
|
||||
"h11==0.14.0 ; python_full_version >= '3.7'",
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
"""Module used for handling cache"""
|
||||
import hashlib
|
||||
|
||||
from pymemcache import serde
|
||||
from pymemcache.client.base import Client
|
||||
|
||||
@@ -73,3 +75,62 @@ else:
|
||||
encoding='utf-8',
|
||||
serde=serde.pickle_serde
|
||||
)
|
||||
|
||||
|
||||
#### Cache for payment architecture
|
||||
|
||||
def make_credit_cache_key(user_id: str, order_id: str) -> str:
|
||||
"""
|
||||
Generate a cache key from user_id and order_id using md5.
|
||||
|
||||
Args:
|
||||
user_id (str): The user's ID.
|
||||
order_id (str): The PayPal order ID.
|
||||
|
||||
Returns:
|
||||
str: A unique cache key.
|
||||
"""
|
||||
# Concatenate and hash to avoid collisions and keep key size small
|
||||
raw_key = f"{user_id}:{order_id}"
|
||||
return hashlib.md5(raw_key.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
class CreditCache:
|
||||
"""
|
||||
Handles storing and retrieving credits to grant for a user/order.
|
||||
|
||||
Methods:
|
||||
set_credits(user_id, order_id, credits):
|
||||
Store the credits for a user/order.
|
||||
|
||||
get_credits(user_id, order_id):
|
||||
Retrieve the credits for a user/order.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def set_credits(user_id: str, order_id: str, credits_to_grant: int) -> None:
|
||||
"""
|
||||
Store the credits to be granted for a user/order.
|
||||
|
||||
Args:
|
||||
user_id (str): The user's ID.
|
||||
order_id (str): The PayPal order ID.
|
||||
credits (int): The amount of credits to grant.
|
||||
"""
|
||||
cache_key = make_credit_cache_key(user_id, order_id)
|
||||
client.set(cache_key, credits_to_grant)
|
||||
|
||||
@staticmethod
|
||||
def get_credits(user_id: str, order_id: str) -> int | None:
|
||||
"""
|
||||
Retrieve the credits to be granted for a user/order.
|
||||
|
||||
Args:
|
||||
user_id (str): The user's ID.
|
||||
order_id (str): The PayPal order ID.
|
||||
|
||||
Returns:
|
||||
int | None: The credits to grant, or None if not found.
|
||||
"""
|
||||
cache_key = make_credit_cache_key(user_id, order_id)
|
||||
return client.get(cache_key)
|
||||
|
||||
0
backend/src/configuration/__init__.py
Normal file
0
backend/src/configuration/__init__.py
Normal file
23
backend/src/configuration/environment.py
Normal file
23
backend/src/configuration/environment.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""This module is for loading variables from the environment and passes them throughout the code using the Environment dataclass"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
# Load variables from environment
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Environment :
|
||||
|
||||
# Load supabase secrets
|
||||
supabase_url = os.environ['SUPABASE_URL']
|
||||
supabase_admin_key = os.environ['SUPABASE_ADMIN_KEY']
|
||||
supabase_test_user_id = os.environ['SUPABASE_TEST_USER_ID']
|
||||
|
||||
# Load paypal secrets
|
||||
paypal_id_sandbox = os.environ['PAYPAL_ID_SANDBOX']
|
||||
paypal_key_sandbox = os.environ['PAYPAL_KEY_SANDBOX']
|
||||
@@ -22,7 +22,6 @@ class LandmarkManager:
|
||||
church_coeff: float # coeff to adjsut score of churches
|
||||
nature_coeff: float # coeff to adjust score of parks
|
||||
overall_coeff: float # coeff to adjust weight of tags
|
||||
# n_important: int # number of important landmarks to consider
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -41,7 +40,6 @@ class LandmarkManager:
|
||||
self.wikipedia_bonus = parameters['wikipedia_bonus']
|
||||
self.viewpoint_bonus = parameters['viewpoint_bonus']
|
||||
self.pay_bonus = parameters['pay_bonus']
|
||||
# self.n_important = parameters['N_important']
|
||||
|
||||
with OPTIMIZER_PARAMETERS_PATH.open('r') as f:
|
||||
parameters = yaml.safe_load(f)
|
||||
@@ -187,6 +185,7 @@ class LandmarkManager:
|
||||
|
||||
# caution, when applying a list of selectors, overpass will search for elements that match ALL selectors simultaneously
|
||||
# we need to split the selectors into separate queries and merge the results
|
||||
# TODO: this can be multi-threaded once the Overpass rate-limit is not a problem anymore
|
||||
for sel in dict_to_selector_list(amenity_selector):
|
||||
# self.logger.debug(f"Current selector: {sel}")
|
||||
|
||||
|
||||
@@ -64,9 +64,10 @@ def get_landmarks(
|
||||
|
||||
@router.post("/get-nearby/landmarks/{lat}/{lon}")
|
||||
def get_landmarks_nearby(
|
||||
lat: float,
|
||||
lon: float
|
||||
) -> list[Landmark] :
|
||||
lat: float,
|
||||
lon: float,
|
||||
allow_clusters: bool = False
|
||||
) -> list[Landmark] :
|
||||
"""
|
||||
Suggests nearby landmarks based on a given latitude and longitude.
|
||||
|
||||
@@ -76,6 +77,7 @@ def get_landmarks_nearby(
|
||||
Args:
|
||||
lat (float): Latitude of the user's current location.
|
||||
lon (float): Longitude of the user's current location.
|
||||
allow_clusters (bool): Whether or not to allow the search for shopping/historical clusters when looking for nearby landmarks.
|
||||
|
||||
Returns:
|
||||
list[Landmark]: A list of selected nearby landmarks.
|
||||
@@ -104,7 +106,7 @@ def get_landmarks_nearby(
|
||||
landmarks_around = manager.generate_landmarks_list(
|
||||
center_coordinates = (lat, lon),
|
||||
preferences = prefs,
|
||||
allow_clusters=False,
|
||||
allow_clusters=allow_clusters,
|
||||
)
|
||||
|
||||
if len(landmarks_around) == 0 :
|
||||
|
||||
@@ -1,28 +1,18 @@
|
||||
"""Main app for backend api"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI
|
||||
|
||||
from .logging_config import configure_logging
|
||||
from .structs.landmark import Landmark
|
||||
from .structs.linked_landmarks import LinkedLandmarks
|
||||
from .structs.trip import Trip
|
||||
from .landmarks.landmarks_manager import LandmarkManager
|
||||
from .toilets.toilets_router import router as toilets_router
|
||||
from .optimization.optimization_router import router as optimization_router
|
||||
from .landmarks.landmarks_router import router as landmarks_router
|
||||
from .payments.payment_router import router as payment_router
|
||||
from .optimization.optimizer import Optimizer
|
||||
from .optimization.refiner import Refiner
|
||||
from .cache import client as cache_client
|
||||
from .trips.trips_router import router as trips_router
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
manager = LandmarkManager()
|
||||
optimizer = Optimizer()
|
||||
refiner = Refiner(optimizer=optimizer)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -33,6 +23,7 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("Shutting down logging")
|
||||
|
||||
|
||||
# Create the fastapi app
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@@ -52,85 +43,17 @@ app.include_router(optimization_router)
|
||||
# Call with "/get/toilets" for fetching toilets around coordinates.
|
||||
app.include_router(toilets_router)
|
||||
|
||||
|
||||
# Include the payment router for interacting with paypal sdk.
|
||||
# See src/payment/payment_router.py for more information on how to call.
|
||||
# Call with "/orders/new" to initiate a payment with an order request (step 1)
|
||||
# Call with "/orders/{order_id}/{user_id}capture" to capture a payment and grant the user the due credits (step 2)
|
||||
app.include_router(payment_router)
|
||||
|
||||
#### For already existing trips/landmarks
|
||||
@app.get("/trip/{trip_uuid}")
|
||||
def get_trip(trip_uuid: str) -> Trip:
|
||||
"""
|
||||
Look-up the cache for a trip that has been previously generated using its identifier.
|
||||
|
||||
Args:
|
||||
trip_uuid (str) : unique identifier for a trip.
|
||||
|
||||
Returns:
|
||||
(Trip) : the corresponding trip.
|
||||
"""
|
||||
try:
|
||||
trip = cache_client.get(f"trip_{trip_uuid}")
|
||||
return trip
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to fetch trip with UUID {trip_uuid}: {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail="Trip not found") from exc
|
||||
|
||||
|
||||
@app.get("/landmark/{landmark_uuid}")
|
||||
def get_landmark(landmark_uuid: str) -> Landmark:
|
||||
"""
|
||||
Returns a Landmark from its unique identifier.
|
||||
|
||||
Args:
|
||||
landmark_uuid (str) : unique identifier for a Landmark.
|
||||
|
||||
Returns:
|
||||
(Landmark) : the corresponding Landmark.
|
||||
"""
|
||||
try:
|
||||
landmark = cache_client.get(f"landmark_{landmark_uuid}")
|
||||
return landmark
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to fetch landmark with UUID {landmark_uuid}: {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail="Landmark not found") from exc
|
||||
|
||||
|
||||
@app.post("/trip/recompute-time/{trip_uuid}/{removed_landmark_uuid}")
|
||||
def update_trip_time(trip_uuid: str, removed_landmark_uuid: str) -> Trip:
|
||||
"""
|
||||
Updates the reaching times of a given trip when removing a landmark.
|
||||
|
||||
Args:
|
||||
landmark_uuid (str) : unique identifier for a Landmark.
|
||||
|
||||
Returns:
|
||||
(Landmark) : the corresponding Landmark.
|
||||
"""
|
||||
# First, fetch the trip in the cache.
|
||||
try:
|
||||
trip = cache_client.get(f'trip_{trip_uuid}')
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to update trip with UUID {trip_uuid} (trip not found): {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail='Trip not found') from exc
|
||||
|
||||
landmarks = []
|
||||
next_uuid = trip.first_landmark_uuid
|
||||
|
||||
# Extract landmarks
|
||||
try :
|
||||
while next_uuid is not None:
|
||||
landmark = cache_client.get(f'landmark_{next_uuid}')
|
||||
# Filter out the removed landmark.
|
||||
if next_uuid != removed_landmark_uuid :
|
||||
landmarks.append(landmark)
|
||||
next_uuid = landmark.next_uuid # Prepare for the next iteration
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to update trip with UUID {trip_uuid} : {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail=f'landmark {next_uuid} not found') from exc
|
||||
|
||||
# Re-link every thing and compute times again
|
||||
linked_tour = LinkedLandmarks(landmarks)
|
||||
trip = Trip.from_linked_landmarks(linked_tour, cache_client)
|
||||
|
||||
return trip
|
||||
# Endpoint for putting together a trip, fetching landmarks by UUID and updating trip times. Three routes
|
||||
# Call with "/trip/{trip_uuid}" for getting trip by UUID.
|
||||
# Call with "/landmark/{landmark_uuid}" for getting landmark by UUID.
|
||||
# Call with "/trip//trip/recompute-time/{trip_uuid}/{removed_landmark_uuid}" for updating trip times.
|
||||
app.include_router(trips_router)
|
||||
|
||||
|
||||
@@ -1,70 +1,357 @@
|
||||
from typing import Literal
|
||||
import paypalrestsdk
|
||||
from pydantic import BaseModel
|
||||
from fastapi import HTTPException
|
||||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from ..configuration.environment import Environment
|
||||
from ..cache import CreditCache, make_credit_cache_key
|
||||
|
||||
|
||||
# Model for payment request body
|
||||
class PaymentRequest(BaseModel):
|
||||
# Intialize the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define the base URL, might move that to toml file
|
||||
BASE_URL_PROD = 'https://api-m.paypal.com'
|
||||
BASE_URL_SANDBOX = 'https://api-m.sandbox.paypal.com'
|
||||
|
||||
|
||||
class BasketItem(BaseModel):
|
||||
"""
|
||||
Represents a single item in the user's basket.
|
||||
|
||||
Attributes:
|
||||
id (str): The unique identifier for the item.
|
||||
quantity (int): The number of units of the item.
|
||||
"""
|
||||
id: str
|
||||
quantity: int
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
"""
|
||||
Represents an item available in the shop.
|
||||
|
||||
Attributes:
|
||||
id (str): The unique identifier for the item.
|
||||
name (str): The name of the item.
|
||||
description (str): The description of the item.
|
||||
unit_price (float): The unit price of the item.
|
||||
"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
unit_price: float
|
||||
unit_credits: int
|
||||
|
||||
|
||||
def item_from_sql(item_id: str):
|
||||
"""
|
||||
Fetches an item from the database by its ID.
|
||||
|
||||
Args:
|
||||
item_id (str): The unique identifier for the item.
|
||||
|
||||
Returns:
|
||||
Item: The item object retrieved from the database.
|
||||
"""
|
||||
# TODO: Replace with actual SQL fetch logic
|
||||
return Item(
|
||||
id = '12345678',
|
||||
name = 'test_item',
|
||||
description = 'lorem ipsum',
|
||||
unit_price = 0.1,
|
||||
unit_credits = 5
|
||||
)
|
||||
|
||||
|
||||
class OrderRequest(BaseModel):
|
||||
"""
|
||||
Represents an order request from the frontend.
|
||||
|
||||
Attributes:
|
||||
user_id (str): The ID of the user placing the order.
|
||||
basket (list[BasketItem]): List of basket items.
|
||||
currency (str): The currency code for the order.
|
||||
created_at (datetime): Timestamp when the order was created.
|
||||
updated_at (datetime): Timestamp when the order was last updated.
|
||||
items (list[Item]): List of item details loaded from the database.
|
||||
total_price (float): Total price of the order.
|
||||
"""
|
||||
user_id: str
|
||||
credit_amount: Literal[10, 50, 100]
|
||||
currency: Literal["USD", "EUR", "CHF"]
|
||||
description: str = "Purchase of credits"
|
||||
basket: list[BasketItem]
|
||||
currency: Literal['CHF', 'EUR', 'USD']
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
items: list[Item] = Field(default_factory=list)
|
||||
total_price: float = None
|
||||
total_credits: int = None
|
||||
|
||||
@field_validator('basket')
|
||||
def validate_basket(cls, v):
|
||||
"""Validates the basket items.
|
||||
|
||||
Args:
|
||||
v (list): List of basket items.
|
||||
|
||||
Raises:
|
||||
ValueError: If basket does not contain valid BasketItem objects.
|
||||
|
||||
Returns:
|
||||
list: The validated basket.
|
||||
"""
|
||||
if not v or not all(isinstance(i, BasketItem) for i in v):
|
||||
raise ValueError('Basket must contain BasketItem objects')
|
||||
return
|
||||
|
||||
def load_items_and_price(self):
|
||||
# This should be automatic upon initialization of the class
|
||||
"""
|
||||
Loads item details from database and calculates the total price as well as the total credits to be granted.
|
||||
"""
|
||||
self.items = []
|
||||
self.total_price = 0
|
||||
self.total_credits = 0
|
||||
for basket_item in self.basket:
|
||||
item = item_from_sql(basket_item.id)
|
||||
self.items.append(item)
|
||||
self.total_price += item.unit_price * basket_item.quantity # increment price
|
||||
self.total_credits += item.unit_credits * basket_item.quantity # increment credit balance
|
||||
|
||||
|
||||
def to_paypal_items(self):
|
||||
"""
|
||||
Converts items to the PayPal API item format.
|
||||
|
||||
Returns:
|
||||
list: List of items formatted for PayPal API.
|
||||
"""
|
||||
item_list = []
|
||||
|
||||
for basket_item, item in zip(self.basket, self.items):
|
||||
item_list.append({
|
||||
'id': item.id,
|
||||
'name': item.name,
|
||||
'description': item.description,
|
||||
'quantity': str(basket_item.quantity),
|
||||
'unit_amount': {
|
||||
'currency_code': self.currency,
|
||||
'value': str(item.unit_price)
|
||||
}
|
||||
})
|
||||
return item_list
|
||||
|
||||
|
||||
# Payment handler class for managing PayPal payments
|
||||
class PaymentHandler:
|
||||
class PaypalClient:
|
||||
"""
|
||||
Handles PayPal payment operations.
|
||||
|
||||
payment_id: str
|
||||
Attributes:
|
||||
sandbox (bool): Whether to use the sandbox environment.
|
||||
id (str): PayPal client ID.
|
||||
key (str): PayPal client secret.
|
||||
base_url (str): Base URL for PayPal API.
|
||||
_token_cache (dict): Cache for the PayPal OAuth access token.
|
||||
"""
|
||||
|
||||
def __init__(self, transaction_details: PaymentRequest):
|
||||
self.details = transaction_details
|
||||
_token_cache = {
|
||||
"access_token": None,
|
||||
"expires_at": 0
|
||||
}
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sandbox_mode: bool = False
|
||||
):
|
||||
"""
|
||||
Initializes the handler.
|
||||
|
||||
Args:
|
||||
sandbox_mode (bool): Whether to use sandbox credentials.
|
||||
"""
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.sandbox = sandbox_mode
|
||||
|
||||
# Only support purchase of credit 'bundles': 10, 50 or 100 credits worth of trip generation
|
||||
def fetch_price(self) -> float:
|
||||
# PayPal keys
|
||||
if sandbox_mode :
|
||||
self.id = Environment.paypal_id_sandbox
|
||||
self.key = Environment.paypal_key_sandbox
|
||||
self.base_url = BASE_URL_SANDBOX
|
||||
else :
|
||||
self.id = Environment.paypal_id_prod
|
||||
self.key = Environment.paypal_key_prod
|
||||
self.base_url = BASE_URL_PROD
|
||||
|
||||
|
||||
|
||||
def _get_access_token(self) -> str | None:
|
||||
"""
|
||||
Fetches the price of credits in the specified currency.
|
||||
Gets (and caches) a PayPal access token.
|
||||
|
||||
Returns:
|
||||
str | None: The access token if successful, None otherwise.
|
||||
"""
|
||||
result = self.supabase.table("prices").select("credit_amount").eq("currency", self.details.currency).single().execute()
|
||||
if result.data:
|
||||
return result.data.get("price")
|
||||
else:
|
||||
self.logger.error(f"Unsupported currency: {self.details.currency}")
|
||||
now = datetime.now()
|
||||
# Check if token is still valid
|
||||
if (
|
||||
self._token_cache["access_token"] is not None
|
||||
and self._token_cache["expires_at"] > now
|
||||
):
|
||||
self.logger.info('Returning (cached) access token.')
|
||||
return self._token_cache["access_token"]
|
||||
|
||||
# Request new token
|
||||
validation_data = {'grant_type': 'client_credentials'}
|
||||
|
||||
try:
|
||||
# pass the request
|
||||
validation_response = requests.post(
|
||||
url = f'{self.base_url}/v1/oauth2/token',
|
||||
data = validation_data,
|
||||
auth =(self.id, self.key)
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
self.logger.error(f'Error while requesting access token: {exc}')
|
||||
return None
|
||||
|
||||
def create_paypal_payment(self) -> str:
|
||||
data = validation_response.json()
|
||||
access_token = data.get("access_token")
|
||||
expires_in = int(data.get("expires_in", 3600)) # seconds, default 1 hour
|
||||
|
||||
# Cache the token and its expiry
|
||||
self._token_cache["access_token"] = access_token
|
||||
self._token_cache["expires_at"] = now + timedelta(seconds=expires_in - 60) # buffer 1 min
|
||||
|
||||
self.logger.info('Returning (new) access token.')
|
||||
return access_token
|
||||
|
||||
|
||||
def order(
|
||||
self,
|
||||
order_request: OrderRequest,
|
||||
return_url_success: str,
|
||||
return_url_failure: str
|
||||
):
|
||||
"""
|
||||
Creates a PayPal payment and returns the approval URL.
|
||||
Creates a new PayPal order.
|
||||
|
||||
Args:
|
||||
order_request (OrderRequest): The order request.
|
||||
|
||||
Returns:
|
||||
dict | None: PayPal order response JSON, or None if failed.
|
||||
"""
|
||||
price = self.fetch_price()
|
||||
payment = paypalrestsdk.Payment({
|
||||
"intent": "sale",
|
||||
"payer": {
|
||||
"payment_method": "paypal"
|
||||
},
|
||||
"transactions": [{
|
||||
"amount": {
|
||||
"total": f"{price:.2f}",
|
||||
"currency": self.details.currency
|
||||
},
|
||||
"description": self.details.description
|
||||
}],
|
||||
"redirect_urls": {
|
||||
"return_url": "http://localhost:8000/payment/success",
|
||||
"cancel_url": "http://localhost:8000/payment/cancel"
|
||||
|
||||
# Fetch details of order from mart database and compute total credits and price
|
||||
order_request.load_items_and_price()
|
||||
|
||||
# Prepare payload for post request to paypal API
|
||||
order_data = {
|
||||
'intent': 'CAPTURE',
|
||||
'purchase_units': [
|
||||
{
|
||||
'items': order_request.to_paypal_items(),
|
||||
'amount': {
|
||||
'currency_code': order_request.currency,
|
||||
'value': str(order_request.total_price),
|
||||
'breakdown': {
|
||||
'item_total': {
|
||||
'currency_code': order_request.currency,
|
||||
'value': str(order_request.total_price)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
'application_context': {
|
||||
'return_url': return_url_success,
|
||||
'cancel_url': return_url_failure
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if payment.create():
|
||||
self.logger.info("Payment created successfully")
|
||||
self.payment_id = payment.id
|
||||
# Get the access_token:
|
||||
access_token = self._get_access_token()
|
||||
|
||||
# Get the approval URL and return it for the user to approve
|
||||
for link in payment.links:
|
||||
if link.rel == "approval_url":
|
||||
return link.href
|
||||
else:
|
||||
self.logger.error(f"Failed to create payment: {payment.error}")
|
||||
raise HTTPException(status_code=500, detail="Payment creation failed")
|
||||
try:
|
||||
order_response = requests.post(
|
||||
url = f'{self.base_url}/v2/checkout/orders',
|
||||
headers = {'Authorization': f'Bearer {access_token}'},
|
||||
json = order_data,
|
||||
)
|
||||
|
||||
# Raise HTTP Exception if request was unsuccessful.
|
||||
except Exception as exc:
|
||||
self.logger.error(f'Error creating PayPal order: {exc}')
|
||||
return None
|
||||
|
||||
order_response.raise_for_status()
|
||||
|
||||
# TODO Now that we have the order ID, we can inscribe the details in sql database using the order id given by paypal
|
||||
# DB for storing the transactions:
|
||||
|
||||
# order_id (key): json.loads(order_response.text)["id"]
|
||||
# user_id : order_request.user_id
|
||||
# created_at : order_request.created_at
|
||||
# status : PENDING
|
||||
# basket (json) : OrderDetails.jsonify()
|
||||
# total_price : order_request.total_price
|
||||
# currency : order_request.currency
|
||||
# updated_at : order_request.created_at
|
||||
|
||||
# Create a cache item for credits to be granted to user
|
||||
CreditCache.set_credits(
|
||||
user_id = order_request.user_id,
|
||||
order_id = json.loads(order_response.text)["id"],
|
||||
credits_to_grant = order_request.total_credits)
|
||||
|
||||
|
||||
return order_response.json()
|
||||
|
||||
|
||||
# Standalone function to capture a payment
|
||||
def capture(self, user_id: str, order_id: str):
|
||||
"""
|
||||
Captures payment for a PayPal order.
|
||||
|
||||
Args:
|
||||
order_id (str): The PayPal order ID.
|
||||
|
||||
Returns:
|
||||
dict | None: PayPal capture response JSON, or None if failed.
|
||||
"""
|
||||
# Get the access_token:
|
||||
access_token = self._get_access_token()
|
||||
|
||||
try:
|
||||
capture_response = requests.post(
|
||||
url = f'{self.base_url}/v2/checkout/orders/{order_id}/capture',
|
||||
headers = {'Authorization': f'Bearer {access_token}'},
|
||||
json = {},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f'Error while requesting access token: {exc}')
|
||||
return None
|
||||
|
||||
# Raise exception if API call failed
|
||||
capture_response.raise_for_status()
|
||||
|
||||
|
||||
|
||||
# print(capture_response.text)
|
||||
|
||||
# TODO: update status to PAID in sql database
|
||||
|
||||
# where order_id (key) = order_id
|
||||
# status : 'PAID'
|
||||
# updated_at : datetime.now()
|
||||
|
||||
|
||||
# Not sure yet if/how to implement that
|
||||
def cancel(self):
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,79 +1,162 @@
|
||||
import logging
|
||||
import paypalrestsdk
|
||||
from fastapi import HTTPException, APIRouter
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Query, Body
|
||||
from ..payments import PaypalClient, OrderRequest
|
||||
from ..supabase.supabase import SupabaseClient
|
||||
from .payment_handler import PaymentRequest, PaymentHandler
|
||||
from ..cache import CreditCache, make_credit_cache_key
|
||||
|
||||
# Set up logging and supabase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create a PayPal & Supabase client
|
||||
paypal_client = PaypalClient(sandbox_mode=False)
|
||||
supabase = SupabaseClient()
|
||||
|
||||
# Configure PayPal SDK
|
||||
paypalrestsdk.configure({
|
||||
"mode": "sandbox", # Use 'live' for production
|
||||
"client_id": "YOUR_PAYPAL_CLIENT_ID",
|
||||
"client_secret": "YOUR_PAYPAL_SECRET"
|
||||
})
|
||||
|
||||
|
||||
# Define the API router
|
||||
# Initialize the API router
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/purchase/credits")
|
||||
def purchase_credits(payment_request: PaymentRequest):
|
||||
# Initialize the logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: add the return url in the API payload to redirect the user to the app.
|
||||
@router.post("/orders/new")
|
||||
def create_order(
|
||||
user_id: str = Query(...),
|
||||
basket: list = Query(...),
|
||||
currency: str = Query(...),
|
||||
return_url_success: str = Query('https://anydev.info'),
|
||||
return_url_failure: str = Query('https://anydev.info')
|
||||
):
|
||||
"""
|
||||
Handles token purchases. Calculates the number of tokens based on the amount paid,
|
||||
updates the user's balance, and processes PayPal payment.
|
||||
Creates a new PayPal order.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user placing the order.
|
||||
basket (list): The basket items.
|
||||
currency (str): The currency code.
|
||||
|
||||
Returns:
|
||||
dict: The PayPal order details.
|
||||
"""
|
||||
payment_handler = PaymentHandler(payment_request)
|
||||
|
||||
# Create PayPal payment and get the approval URL
|
||||
approval_url = payment_handler.create_paypal_payment()
|
||||
# Create order :
|
||||
order = OrderRequest(
|
||||
user_id = user_id,
|
||||
basket=basket,
|
||||
currency=currency
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Purchase initiated successfully",
|
||||
"payment_id": payment_handler.payment_id,
|
||||
"credits": payment_request.credit_amount,
|
||||
"approval_url": approval_url,
|
||||
}
|
||||
# Process the order and return the details
|
||||
return paypal_client.order(order_request=order, return_url_success=return_url_success, return_url_failure=return_url_failure)
|
||||
|
||||
|
||||
@router.get("/payment/success")
|
||||
def payment_success(paymentId: str, PayerID: str):
|
||||
|
||||
@router.post("/orders/{order_id}/{user_id}capture")
|
||||
def capture_order(order_id: str, user_id: str):
|
||||
"""
|
||||
Handles successful PayPal payment.
|
||||
Captures payment for an existing PayPal order.
|
||||
|
||||
Args:
|
||||
order_id (str): The PayPal order ID.
|
||||
|
||||
Returns:
|
||||
dict: The PayPal capture response.
|
||||
"""
|
||||
payment = paypalrestsdk.Payment.find(paymentId)
|
||||
# Capture the payment
|
||||
result = paypal_client.capture(order_id)
|
||||
|
||||
if payment.execute({"payer_id": PayerID}):
|
||||
logger.info("Payment executed successfully")
|
||||
|
||||
# Retrieve transaction details from the database
|
||||
result = supabase.table("pending_payments").select("*").eq("payment_id", paymentId).single().execute()
|
||||
if not result.data:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
# Extract the necessary information
|
||||
user_id = result.data["user_id"]
|
||||
credit_amount = result.data["credit_amount"]
|
||||
|
||||
# Update the user's balance
|
||||
supabase.increment_credit_balance(user_id, amount=credit_amount)
|
||||
|
||||
# Optionally, delete the pending payment entry since the transaction is completed
|
||||
supabase.table("pending_payments").delete().eq("payment_id", paymentId).execute()
|
||||
|
||||
return {"message": "Payment completed successfully"}
|
||||
# Grant the user the correct amount of credits:
|
||||
credits = CreditCache.get_credits(user_id, order_id)
|
||||
if credits:
|
||||
supabase.increment_credit_balance(
|
||||
user_id=user_id,
|
||||
amount=credits
|
||||
)
|
||||
logger.info('Payment capture succeeded: incrementing balance of user {user_id} by {credits}.')
|
||||
else:
|
||||
logger.error(f"Payment execution failed: {payment.error}")
|
||||
raise HTTPException(status_code=500, detail="Payment execution failed")
|
||||
logger.error('Capture payment failed. Could not find cache key for user {user_id} and order {order_id}')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/payment/cancel")
|
||||
def payment_cancel():
|
||||
"""
|
||||
Handles PayPal payment cancellation.
|
||||
"""
|
||||
return {"message": "Payment was cancelled"}
|
||||
|
||||
|
||||
# import logging
|
||||
# import paypalrestsdk
|
||||
# from fastapi import HTTPException, APIRouter
|
||||
|
||||
# from ..supabase.supabase import SupabaseClient
|
||||
# from .payment_handler import PaymentRequest, PaymentHandler
|
||||
|
||||
# # Set up logging and supabase
|
||||
# logger = logging.getLogger(__name__)
|
||||
# supabase = SupabaseClient()
|
||||
|
||||
# # Configure PayPal SDK
|
||||
# paypalrestsdk.configure({
|
||||
# "mode": "sandbox", # Use 'live' for production
|
||||
# "client_id": "YOUR_PAYPAL_CLIENT_ID",
|
||||
# "client_secret": "YOUR_PAYPAL_SECRET"
|
||||
# })
|
||||
|
||||
|
||||
# # Define the API router
|
||||
# router = APIRouter()
|
||||
|
||||
# @router.post("/purchase/credits")
|
||||
# def purchase_credits(payment_request: PaymentRequest):
|
||||
# """
|
||||
# Handles token purchases. Calculates the number of tokens based on the amount paid,
|
||||
# updates the user's balance, and processes PayPal payment.
|
||||
# """
|
||||
# payment_handler = PaymentHandler(payment_request)
|
||||
|
||||
# # Create PayPal payment and get the approval URL
|
||||
# approval_url = payment_handler.create_paypal_payment()
|
||||
|
||||
# return {
|
||||
# "message": "Purchase initiated successfully",
|
||||
# "payment_id": payment_handler.payment_id,
|
||||
# "credits": payment_request.credit_amount,
|
||||
# "approval_url": approval_url,
|
||||
# }
|
||||
|
||||
|
||||
# @router.get("/payment/success")
|
||||
# def payment_success(paymentId: str, PayerID: str):
|
||||
# """
|
||||
# Handles successful PayPal payment.
|
||||
# """
|
||||
# payment = paypalrestsdk.Payment.find(paymentId)
|
||||
|
||||
# if payment.execute({"payer_id": PayerID}):
|
||||
# logger.info("Payment executed successfully")
|
||||
|
||||
# # Retrieve transaction details from the database
|
||||
# result = supabase.table("pending_payments").select("*").eq("payment_id", paymentId).single().execute()
|
||||
# if not result.data:
|
||||
# raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
# # Extract the necessary information
|
||||
# user_id = result.data["user_id"]
|
||||
# credit_amount = result.data["credit_amount"]
|
||||
|
||||
# # Update the user's balance
|
||||
# supabase.increment_credit_balance(user_id, amount=credit_amount)
|
||||
|
||||
# # Optionally, delete the pending payment entry since the transaction is completed
|
||||
# supabase.table("pending_payments").delete().eq("payment_id", paymentId).execute()
|
||||
|
||||
# return {"message": "Payment completed successfully"}
|
||||
# else:
|
||||
# logger.error(f"Payment execution failed: {payment.error}")
|
||||
# raise HTTPException(status_code=500, detail="Payment execution failed")
|
||||
|
||||
|
||||
# @router.get("/payment/cancel")
|
||||
# def payment_cancel():
|
||||
# """
|
||||
# Handles PayPal payment cancellation.
|
||||
# """
|
||||
# return {"message": "Payment was cancelled"}
|
||||
|
||||
|
||||
111
backend/src/payments/test.py
Normal file
111
backend/src/payments/test.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#%%
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# username and password
|
||||
load_dotenv(override=True)
|
||||
username = os.environ['PAYPAL_ID_SANDBOX']
|
||||
password = os.environ['PAYPAL_KEY_SANDBOX']
|
||||
|
||||
|
||||
# DOCUMENTATION AT : https://developer.paypal.com/api/rest/requests/
|
||||
|
||||
|
||||
#%%
|
||||
######## STEP 1: Validation ########
|
||||
# url for validation post request
|
||||
validation_url = "https://api-m.sandbox.paypal.com/v1/oauth2/token"
|
||||
validation_url_prod = "https://api-m.paypal.com/v1/oauth2/token"
|
||||
|
||||
# payload for the post request
|
||||
validation_data = {'grant_type': 'client_credentials'}
|
||||
|
||||
# pass the request
|
||||
validation_response = requests.post(
|
||||
url=validation_url,
|
||||
data=validation_data,
|
||||
auth=(username, password)
|
||||
)
|
||||
|
||||
# todo check status code + try except. Status code 201 ?
|
||||
print(f'Reponse status code: {validation_response.status_code}')
|
||||
print(f'Access token: {json.loads(validation_response.text)["access_token"]}')
|
||||
access_token = json.loads(validation_response.text)["access_token"]
|
||||
|
||||
|
||||
#%%
|
||||
######## STEP 2: Create Order ########
|
||||
# url for post request
|
||||
order_url = "https://api-m.sandbox.paypal.com/v2/checkout/orders"
|
||||
order_url_prod = "https://api-m.paypal.com/v2/checkout/orders"
|
||||
|
||||
# payload for the request
|
||||
order_data = {
|
||||
"intent": "CAPTURE",
|
||||
"purchase_units": [
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"name": "AnyWay Credits",
|
||||
"description": "50 pack of credits",
|
||||
"quantity": 1,
|
||||
"unit_amount": {
|
||||
"currency_code": "CHF",
|
||||
"value": "1.50"
|
||||
}
|
||||
|
||||
}
|
||||
],
|
||||
"amount": {
|
||||
"currency_code": "CHF",
|
||||
"value": "1.50",
|
||||
"breakdown": {
|
||||
"item_total": {
|
||||
"currency_code": "CHF",
|
||||
"value": "1.50"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"application_context": {
|
||||
"return_url": "https://anydev.info",
|
||||
"cancel_url": "https://anydev.info"
|
||||
}
|
||||
}
|
||||
|
||||
order_response = requests.post(
|
||||
url=order_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"}, ## need access token here?
|
||||
json=order_data,
|
||||
auth=(username, password)
|
||||
)
|
||||
|
||||
# Send the redirect link to the user
|
||||
# print(order_response.json())
|
||||
for link_obj in order_response.json()['links']:
|
||||
if link_obj['rel'] == 'approve':
|
||||
forward_to_user_link = link_obj['href']
|
||||
print(f'Reponse status code: {order_response.status_code}')
|
||||
print(f'Follow this link to proceed to payment: {forward_to_user_link}')
|
||||
order_id = json.loads(order_response.text)["id"]
|
||||
|
||||
|
||||
#%%
|
||||
######## STEP 3: capture payment
|
||||
# url for post request
|
||||
capture_url = f"https://api-m.sandbox.paypal.com/v2/checkout/orders/{order_id}/capture"
|
||||
# capture_url_prod = f"https://api-m.paypal.com/v2/checkout/orders/{order_id}/capture"
|
||||
|
||||
capture_response = requests.post(
|
||||
url=capture_url,
|
||||
json={},
|
||||
auth=(username, password)
|
||||
)
|
||||
|
||||
# todo check status code + try except
|
||||
print(f'Reponse status code: {capture_response.status_code}')
|
||||
print(capture_response.text)
|
||||
# order_id = json.loads(response.text)["id"]
|
||||
@@ -5,6 +5,7 @@ from fastapi import HTTPException, status
|
||||
from supabase import create_client, Client, ClientOptions
|
||||
|
||||
from ..constants import PARAMETERS_DIR
|
||||
from ..configuration.environment import Environment
|
||||
|
||||
# Silence the supabase logger
|
||||
logging.getLogger("httpx").setLevel(logging.CRITICAL)
|
||||
@@ -18,11 +19,9 @@ class SupabaseClient:
|
||||
|
||||
def __init__(self):
|
||||
|
||||
with open(os.path.join(PARAMETERS_DIR, 'secrets.yaml')) as f:
|
||||
secrets = yaml.safe_load(f)
|
||||
self.SUPABASE_URL = secrets['SUPABASE_URL']
|
||||
self.SUPABASE_ADMIN_KEY = secrets['SUPABASE_ADMIN_KEY']
|
||||
self.SUPABASE_TEST_USER_ID = secrets['SUPABASE_TEST_USER_ID']
|
||||
self.SUPABASE_URL = Environment.supabase_url
|
||||
self.SUPABASE_ADMIN_KEY = Environment.supabase_admin_key
|
||||
self.SUPABASE_TEST_USER_ID = Environment.supabase_test_user_id
|
||||
|
||||
self.supabase = create_client(
|
||||
self.SUPABASE_URL,
|
||||
|
||||
53
backend/src/tests/test_payment.py
Normal file
53
backend/src/tests/test_payment.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Collection of tests to ensure correct implementation and track progress of paypal payments."""
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
|
||||
from ..main import app
|
||||
from ..supabase.supabase import SupabaseClient
|
||||
|
||||
|
||||
# Create a supabase client
|
||||
supabase = SupabaseClient()
|
||||
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
"""Client used to call the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_nearby(client): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Test n°1 : Verify handling of invalid input.
|
||||
|
||||
Args:
|
||||
client:
|
||||
request:
|
||||
"""
|
||||
response = client.post(
|
||||
url=f"/orders/new/",
|
||||
json={
|
||||
'user_id': supabase.SUPABASE_TEST_USER_ID,
|
||||
'basket': {
|
||||
{
|
||||
'id': '1873672819',
|
||||
'quantity': 1982
|
||||
},
|
||||
{
|
||||
'id': '9876789',
|
||||
'quantity': 1982
|
||||
}
|
||||
},
|
||||
'currency': 'CHF',
|
||||
'return_url_success': 'https://anydev.info',
|
||||
'return_url_failure': 'https://anydev.info'
|
||||
}
|
||||
)
|
||||
suggestions = response.json()
|
||||
|
||||
# checks :
|
||||
assert response.status_code == 200 # check for successful planning
|
||||
assert isinstance(suggestions, list) # check that the return type is a list
|
||||
assert len(suggestions) > 0
|
||||
@@ -1,14 +1,11 @@
|
||||
"""Collection of tests to ensure correct implementation and track progress."""
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
|
||||
from .test_utils import load_trip_landmarks, log_trip_details
|
||||
from ..supabase.supabase import SupabaseClient
|
||||
from ..structs.preferences import Preferences, Preference
|
||||
from ..constants import PARAMETERS_DIR
|
||||
from ..main import app
|
||||
|
||||
|
||||
@@ -59,8 +56,8 @@ def test_trip(client, request, sightseeing, shopping, nature, max_time_minute, s
|
||||
"preferences": prefs.model_dump(),
|
||||
"start": start_coords,
|
||||
"end": end_coords,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
landmarks = response.json()
|
||||
@@ -74,8 +71,8 @@ def test_trip(client, request, sightseeing, shopping, nature, max_time_minute, s
|
||||
"landmarks": landmarks,
|
||||
"start": start,
|
||||
"end": end,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Increment the user balance again
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
"""Helper methods for testing."""
|
||||
import time
|
||||
import logging
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ..cache import client as cache_client
|
||||
from ..structs.landmark import Landmark
|
||||
from ..structs.preferences import Preferences, Preference
|
||||
|
||||
|
||||
def landmarks_to_osmid(landmarks: list[Landmark]) -> list[int] :
|
||||
@@ -94,34 +91,3 @@ def log_trip_details(request, landmarks: list[Landmark], duration: int, target_d
|
||||
request.node.trip_details = trip_string
|
||||
request.node.trip_duration = str(duration) # result['total_time']
|
||||
request.node.target_duration = str(target_duration)
|
||||
|
||||
|
||||
|
||||
|
||||
def trip_params(
|
||||
sightseeing: int,
|
||||
shopping: int,
|
||||
nature: int,
|
||||
max_time_minute: int,
|
||||
start_coords: tuple[float, float] = None,
|
||||
end_coords: tuple[float, float] = None,
|
||||
):
|
||||
def decorator(test_func):
|
||||
@wraps(test_func)
|
||||
def wrapper(client, request):
|
||||
prefs = Preferences(
|
||||
sightseeing=Preference(type='sightseeing', score=sightseeing),
|
||||
shopping=Preference(type='shopping', score=shopping),
|
||||
nature=Preference(type='nature', score=nature),
|
||||
max_time_minute=max_time_minute,
|
||||
detour_tolerance_minute=0,
|
||||
)
|
||||
|
||||
start = start_coords
|
||||
end = end_coords
|
||||
|
||||
# Inject into test function
|
||||
return test_func(client, request, prefs, start, end)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
0
backend/src/trips/__init__.py
Normal file
0
backend/src/trips/__init__.py
Normal file
104
backend/src/trips/trips_router.py
Normal file
104
backend/src/trips/trips_router.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from ..structs.landmark import Landmark
|
||||
from ..structs.linked_landmarks import LinkedLandmarks
|
||||
from ..structs.trip import Trip
|
||||
from ..landmarks.landmarks_manager import LandmarkManager
|
||||
from ..optimization.optimizer import Optimizer
|
||||
from ..optimization.refiner import Refiner
|
||||
from ..cache import client as cache_client
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
manager = LandmarkManager()
|
||||
optimizer = Optimizer()
|
||||
refiner = Refiner(optimizer=optimizer)
|
||||
|
||||
|
||||
# Initialize the API router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
#### For already existing trips/landmarks
|
||||
@router.get("/trip/{trip_uuid}")
|
||||
def get_trip(trip_uuid: str) -> Trip:
|
||||
"""
|
||||
Look-up the cache for a trip that has been previously generated using its identifier.
|
||||
|
||||
Args:
|
||||
trip_uuid (str) : unique identifier for a trip.
|
||||
|
||||
Returns:
|
||||
(Trip) : the corresponding trip.
|
||||
"""
|
||||
try:
|
||||
trip = cache_client.get(f"trip_{trip_uuid}")
|
||||
return trip
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to fetch trip with UUID {trip_uuid}: {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail="Trip not found") from exc
|
||||
|
||||
|
||||
# Fetch a landmark from memcached by its uuid
|
||||
@router.get("/landmark/{landmark_uuid}")
|
||||
def get_landmark(landmark_uuid: str) -> Landmark:
|
||||
"""
|
||||
Returns a Landmark from its unique identifier.
|
||||
|
||||
Args:
|
||||
landmark_uuid (str) : unique identifier for a Landmark.
|
||||
|
||||
Returns:
|
||||
(Landmark) : the corresponding Landmark.
|
||||
"""
|
||||
try:
|
||||
landmark = cache_client.get(f"landmark_{landmark_uuid}")
|
||||
return landmark
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to fetch landmark with UUID {landmark_uuid}: {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail="Landmark not found") from exc
|
||||
|
||||
|
||||
# Update the times between landmarks when removing an item from the list
|
||||
@router.post("/trip/recompute-time/{trip_uuid}/{removed_landmark_uuid}")
|
||||
def update_trip_time(trip_uuid: str, removed_landmark_uuid: str) -> Trip:
|
||||
"""
|
||||
Updates the reaching times of a given trip when removing a landmark.
|
||||
|
||||
Args:
|
||||
landmark_uuid (str) : unique identifier for a Landmark.
|
||||
|
||||
Returns:
|
||||
(Landmark) : the corresponding Landmark.
|
||||
"""
|
||||
# First, fetch the trip in the cache.
|
||||
try:
|
||||
trip = cache_client.get(f'trip_{trip_uuid}')
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to update trip with UUID {trip_uuid} (trip not found): {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail='Trip not found') from exc
|
||||
|
||||
landmarks = []
|
||||
next_uuid = trip.first_landmark_uuid
|
||||
|
||||
# Extract landmarks
|
||||
try :
|
||||
while next_uuid is not None:
|
||||
landmark = cache_client.get(f'landmark_{next_uuid}')
|
||||
# Filter out the removed landmark.
|
||||
if next_uuid != removed_landmark_uuid :
|
||||
landmarks.append(landmark)
|
||||
next_uuid = landmark.next_uuid # Prepare for the next iteration
|
||||
except KeyError as exc:
|
||||
logger.error(f"Failed to update trip with UUID {trip_uuid} : {str(exc)}")
|
||||
raise HTTPException(status_code=404, detail=f'landmark {next_uuid} not found') from exc
|
||||
|
||||
# Re-link every thing and compute times again
|
||||
linked_tour = LinkedLandmarks(landmarks)
|
||||
trip = Trip.from_linked_landmarks(linked_tour, cache_client)
|
||||
|
||||
return trip
|
||||
|
||||
15
backend/uv.lock
generated
15
backend/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
|
||||
[[package]]
|
||||
@@ -135,6 +135,7 @@ dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "charset-normalizer" },
|
||||
{ name = "click" },
|
||||
{ name = "dotenv" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "fastapi-cli" },
|
||||
{ name = "h11" },
|
||||
@@ -188,6 +189,7 @@ requires-dist = [
|
||||
{ name = "certifi", marker = "python_full_version >= '3.6'", specifier = "==2024.12.14" },
|
||||
{ name = "charset-normalizer", marker = "python_full_version >= '3.7'", specifier = "==3.4.1" },
|
||||
{ name = "click", marker = "python_full_version >= '3.7'", specifier = "==8.1.8" },
|
||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||
{ name = "fastapi", marker = "python_full_version >= '3.8'", specifier = "==0.115.7" },
|
||||
{ name = "fastapi-cli", marker = "python_full_version >= '3.8'", specifier = "==0.0.7" },
|
||||
{ name = "h11", marker = "python_full_version >= '3.7'", specifier = "==0.14.0" },
|
||||
@@ -414,6 +416,17 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/c3/253a89ee03fc9b9682f1541728eb66db7db22148cd94f89ab22528cd1e1b/deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a", size = 11178, upload-time = "2020-04-20T14:23:36.581Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.9.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "python-dotenv" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "executing"
|
||||
version = "2.2.0"
|
||||
|
||||
Reference in New Issue
Block a user